mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-24 13:22:05 +00:00
Change structure of THP implementation [Part 1]
This commit is contained in:
parent
fe9167ffa2
commit
0927bbeb68
6
core/src/all_modules.py
generated
6
core/src/all_modules.py
generated
@ -201,14 +201,20 @@ trezor.wire.context
|
|||||||
import trezor.wire.context
|
import trezor.wire.context
|
||||||
trezor.wire.errors
|
trezor.wire.errors
|
||||||
import trezor.wire.errors
|
import trezor.wire.errors
|
||||||
|
trezor.wire.message_handler
|
||||||
|
import trezor.wire.message_handler
|
||||||
trezor.wire.protocol
|
trezor.wire.protocol
|
||||||
import trezor.wire.protocol
|
import trezor.wire.protocol
|
||||||
trezor.wire.protocol_common
|
trezor.wire.protocol_common
|
||||||
import trezor.wire.protocol_common
|
import trezor.wire.protocol_common
|
||||||
trezor.wire.thp.ack_handler
|
trezor.wire.thp.ack_handler
|
||||||
import trezor.wire.thp.ack_handler
|
import trezor.wire.thp.ack_handler
|
||||||
|
trezor.wire.thp.channel_context
|
||||||
|
import trezor.wire.thp.channel_context
|
||||||
trezor.wire.thp.checksum
|
trezor.wire.thp.checksum
|
||||||
import trezor.wire.thp.checksum
|
import trezor.wire.thp.checksum
|
||||||
|
trezor.wire.thp.session_context
|
||||||
|
import trezor.wire.thp.session_context
|
||||||
trezor.wire.thp.thp_messages
|
trezor.wire.thp.thp_messages
|
||||||
import trezor.wire.thp.thp_messages
|
import trezor.wire.thp.thp_messages
|
||||||
trezor.wire.thp.thp_session
|
trezor.wire.thp.thp_session
|
||||||
|
@ -43,7 +43,7 @@ if __debug__:
|
|||||||
|
|
||||||
layout_change_chan = loop.chan()
|
layout_change_chan = loop.chan()
|
||||||
|
|
||||||
DEBUG_CONTEXT: context.Context | None = None
|
DEBUG_CONTEXT: context.CodecContext | None = None
|
||||||
|
|
||||||
LAYOUT_WATCHER_NONE = 0
|
LAYOUT_WATCHER_NONE = 0
|
||||||
LAYOUT_WATCHER_STATE = 1
|
LAYOUT_WATCHER_STATE = 1
|
||||||
|
@ -27,11 +27,11 @@ from micropython import const
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from storage.cache_common import InvalidSessionError
|
from storage.cache_common import InvalidSessionError
|
||||||
from trezor import log, loop, protobuf, utils, workflow
|
from trezor import log, loop, protobuf, utils
|
||||||
from trezor.enums import FailureType
|
from trezor.enums import FailureType
|
||||||
from trezor.messages import Failure
|
from trezor.messages import Failure
|
||||||
from trezor.wire import codec_v1, context, protocol_common
|
from trezor.wire import codec_v1, context, message_handler, protocol_common, thp_v1
|
||||||
from trezor.wire.errors import ActionCancelled, DataError, Error
|
from trezor.wire.errors import DataError, Error
|
||||||
|
|
||||||
# Import all errors into namespace, so that `wire.Error` is available from
|
# Import all errors into namespace, so that `wire.Error` is available from
|
||||||
# other packages.
|
# other packages.
|
||||||
@ -88,113 +88,43 @@ if __debug__:
|
|||||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_single_message(
|
async def handle_thp_session(iface: WireInterface, is_debug_session: bool = False):
|
||||||
ctx: context.Context, msg: protocol_common.MessageWithId, use_workflow: bool
|
if __debug__ and is_debug_session:
|
||||||
) -> protocol_common.MessageWithId | None:
|
ctx_buffer = WIRE_BUFFER_DEBUG
|
||||||
"""Handle a message that was loaded from USB by the caller.
|
else:
|
||||||
|
ctx_buffer = WIRE_BUFFER
|
||||||
|
|
||||||
Find the appropriate handler, run it and write its result on the wire. In case
|
thp_v1.set_buffer(ctx_buffer)
|
||||||
a problem is encountered at any point, write the appropriate error on the wire.
|
|
||||||
|
|
||||||
If the workflow finished normally or with an error, the return value is None.
|
if __debug__ and is_debug_session:
|
||||||
|
import apps.debug
|
||||||
|
|
||||||
If an unexpected message had arrived on the wire while the workflow was processing,
|
print(apps.debug.DEBUG_CONTEXT) # TODO remove
|
||||||
the workflow is shut down with an `UnexpectedMessage` exception. This is not
|
|
||||||
considered an "error condition" to return over the wire -- instead the message
|
# TODO add debug context or smth to apps.debug
|
||||||
is processed as if starting a new workflow.
|
|
||||||
In such case, the `UnexpectedMessage` is caught and the message is returned
|
# Take a mark of modules that are imported at this point, so we can
|
||||||
to the caller. It will then be processed in the next iteration of the message loop.
|
# roll back and un-import any others.
|
||||||
"""
|
modules = utils.unimport_begin()
|
||||||
if __debug__:
|
|
||||||
|
while True:
|
||||||
try:
|
try:
|
||||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
await thp_v1.thp_main_loop(iface, is_debug_session)
|
||||||
except Exception:
|
|
||||||
msg_type = f"{msg.type} - unknown message type"
|
|
||||||
if ctx.session_id is not None:
|
|
||||||
sid = int.from_bytes(ctx.session_id, "big")
|
|
||||||
else:
|
|
||||||
sid = -1
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%s:%x receive: <%s>",
|
|
||||||
ctx.iface.iface_num(),
|
|
||||||
sid,
|
|
||||||
msg_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
res_msg: protobuf.MessageType | None = None
|
if not __debug__ or not is_debug_session:
|
||||||
|
# Unload modules imported by the workflow. Should not raise.
|
||||||
|
# This is not done for the debug session because the snapshot taken
|
||||||
|
# in a debug session would clear modules which are in use by the
|
||||||
|
# workflow running on wire.
|
||||||
|
utils.unimport_end(modules)
|
||||||
|
loop.clear()
|
||||||
|
return
|
||||||
|
|
||||||
# We need to find a handler for this message type. Should not raise.
|
except Exception as exc:
|
||||||
handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none
|
# Log and try again. The session handler can only exit explicitly via
|
||||||
|
# loop.clear() above.
|
||||||
if handler is None:
|
if __debug__:
|
||||||
# If no handler is found, we can skip decoding and directly
|
|
||||||
# respond with failure.
|
|
||||||
await ctx.write(unexpected_message())
|
|
||||||
return None
|
|
||||||
|
|
||||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
|
||||||
workflow.autolock_interrupts_workflow = False
|
|
||||||
|
|
||||||
# Here we make sure we always respond with a Failure response
|
|
||||||
# in case of any errors.
|
|
||||||
try:
|
|
||||||
# Find a protobuf.MessageType subclass that describes this
|
|
||||||
# message. Raises if the type is not found.
|
|
||||||
req_type = protobuf.type_for_wire(msg.type)
|
|
||||||
|
|
||||||
# Try to decode the message according to schema from
|
|
||||||
# `req_type`. Raises if the message is malformed.
|
|
||||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
|
||||||
|
|
||||||
# Create the handler task.
|
|
||||||
task = handler(req_msg)
|
|
||||||
|
|
||||||
# Run the workflow task. Workflow can do more on-the-wire
|
|
||||||
# communication inside, but it should eventually return a
|
|
||||||
# response message, or raise an exception (a rather common
|
|
||||||
# thing to do). Exceptions are handled in the code below.
|
|
||||||
if use_workflow:
|
|
||||||
# Spawn a workflow around the task. This ensures that concurrent
|
|
||||||
# workflows are shut down.
|
|
||||||
res_msg = await workflow.spawn(context.with_context(ctx, task))
|
|
||||||
else:
|
|
||||||
# For debug messages, ignore workflow processing and just await
|
|
||||||
# results of the handler.
|
|
||||||
res_msg = await task
|
|
||||||
|
|
||||||
except context.UnexpectedMessage as exc:
|
|
||||||
# Workflow was trying to read a message from the wire, and
|
|
||||||
# something unexpected came in. See Context.read() for
|
|
||||||
# example, which expects some particular message and raises
|
|
||||||
# UnexpectedMessage if another one comes in.
|
|
||||||
# In order not to lose the message, we return it to the caller.
|
|
||||||
# TODO:
|
|
||||||
# We might handle only the few common cases here, like
|
|
||||||
# Initialize and Cancel.
|
|
||||||
return exc.msg
|
|
||||||
|
|
||||||
except BaseException as exc:
|
|
||||||
# Either:
|
|
||||||
# - the message had a type that has a registered handler, but does not have
|
|
||||||
# a protobuf class
|
|
||||||
# - the message was not valid protobuf
|
|
||||||
# - workflow raised some kind of an exception while running
|
|
||||||
# - something canceled the workflow from the outside
|
|
||||||
if __debug__:
|
|
||||||
if isinstance(exc, ActionCancelled):
|
|
||||||
log.debug(__name__, "cancelled: %s", exc.message)
|
|
||||||
elif isinstance(exc, loop.TaskClosed):
|
|
||||||
log.debug(__name__, "cancelled: loop task was closed")
|
|
||||||
else:
|
|
||||||
log.exception(__name__, exc)
|
log.exception(__name__, exc)
|
||||||
res_msg = failure(exc)
|
|
||||||
|
|
||||||
if res_msg is not None:
|
|
||||||
# perform the write outside the big try-except block, so that usb write
|
|
||||||
# problem bubbles up
|
|
||||||
await ctx.write(res_msg)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_session(
|
async def handle_session(
|
||||||
@ -205,7 +135,7 @@ async def handle_session(
|
|||||||
else:
|
else:
|
||||||
ctx_buffer = WIRE_BUFFER
|
ctx_buffer = WIRE_BUFFER
|
||||||
|
|
||||||
ctx = context.Context(iface, ctx_buffer, session_id)
|
ctx = context.CodecContext(iface, ctx_buffer, session_id)
|
||||||
next_msg: protocol_common.MessageWithId | None = None
|
next_msg: protocol_common.MessageWithId | None = None
|
||||||
|
|
||||||
if __debug__ and is_debug_session:
|
if __debug__ and is_debug_session:
|
||||||
@ -235,10 +165,10 @@ async def handle_session(
|
|||||||
next_msg = None
|
next_msg = None
|
||||||
|
|
||||||
# Set ctx.session_id to the value msg.session_id
|
# Set ctx.session_id to the value msg.session_id
|
||||||
ctx.session_id = msg.session_id
|
ctx.channel_id = msg.session_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
next_msg = await _handle_single_message(
|
next_msg = await message_handler.handle_single_message(
|
||||||
ctx, msg, use_workflow=not is_debug_session
|
ctx, msg, use_workflow=not is_debug_session
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
|
|||||||
import trezor.wire.protocol as protocol
|
import trezor.wire.protocol as protocol
|
||||||
from trezor import log, loop, protobuf
|
from trezor import log, loop, protobuf
|
||||||
|
|
||||||
from .protocol_common import MessageWithId
|
from .protocol_common import Context, MessageWithId
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
@ -54,10 +54,10 @@ class UnexpectedMessage(Exception):
|
|||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class CodecContext(Context):
|
||||||
"""Wire context.
|
"""Wire context.
|
||||||
|
|
||||||
Represents USB communication inside a particular session on a particular interface
|
Represents USB communication inside a particular session (channel) on a particular interface
|
||||||
(i.e., wire, debug, single BT connection, etc.)
|
(i.e., wire, debug, single BT connection, etc.)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -65,11 +65,12 @@ class Context:
|
|||||||
self,
|
self,
|
||||||
iface: WireInterface,
|
iface: WireInterface,
|
||||||
buffer: bytearray,
|
buffer: bytearray,
|
||||||
session_id: bytes | None = None,
|
channel_id: bytes | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.iface = iface
|
self.iface = iface
|
||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
self.session_id = session_id
|
self.channel_id = channel_id
|
||||||
|
super().__init__(iface, channel_id)
|
||||||
|
|
||||||
def read_from_wire(self) -> Awaitable[MessageWithId]:
|
def read_from_wire(self) -> Awaitable[MessageWithId]:
|
||||||
"""Read a whole message from the wire without parsing it."""
|
"""Read a whole message from the wire without parsing it."""
|
||||||
@ -99,8 +100,8 @@ class Context:
|
|||||||
to save on having to decode the type code into a protobuf class.
|
to save on having to decode the type code into a protobuf class.
|
||||||
"""
|
"""
|
||||||
if __debug__:
|
if __debug__:
|
||||||
if self.session_id is not None:
|
if self.channel_id is not None:
|
||||||
sid = int.from_bytes(self.session_id, "big")
|
sid = int.from_bytes(self.channel_id, "big")
|
||||||
else:
|
else:
|
||||||
sid = -1
|
sid = -1
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -126,8 +127,8 @@ class Context:
|
|||||||
expected_type = protobuf.type_for_wire(msg.type)
|
expected_type = protobuf.type_for_wire(msg.type)
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
if self.session_id is not None:
|
if self.channel_id is not None:
|
||||||
sid = int.from_bytes(self.session_id, "big")
|
sid = int.from_bytes(self.channel_id, "big")
|
||||||
else:
|
else:
|
||||||
sid = -1
|
sid = -1
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -146,8 +147,8 @@ class Context:
|
|||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
async def write(self, msg: protobuf.MessageType) -> None:
|
||||||
"""Write a message to the wire."""
|
"""Write a message to the wire."""
|
||||||
if __debug__:
|
if __debug__:
|
||||||
if self.session_id is not None:
|
if self.channel_id is not None:
|
||||||
sid = int.from_bytes(self.session_id, "big")
|
sid = int.from_bytes(self.channel_id, "big")
|
||||||
else:
|
else:
|
||||||
sid = -1
|
sid = -1
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -173,8 +174,8 @@ class Context:
|
|||||||
msg_size = protobuf.encode(buffer, msg)
|
msg_size = protobuf.encode(buffer, msg)
|
||||||
|
|
||||||
msg_session_id = None
|
msg_session_id = None
|
||||||
if self.session_id is not None:
|
if self.channel_id is not None:
|
||||||
msg_session_id = bytearray(self.session_id)
|
msg_session_id = bytearray(self.channel_id)
|
||||||
await protocol.write_message(
|
await protocol.write_message(
|
||||||
self.iface,
|
self.iface,
|
||||||
MessageWithId(
|
MessageWithId(
|
||||||
@ -185,7 +186,7 @@ class Context:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
CURRENT_CONTEXT: Context | None = None
|
CURRENT_CONTEXT: CodecContext | None = None
|
||||||
|
|
||||||
|
|
||||||
def wait(task: Awaitable[T]) -> Awaitable[T]:
|
def wait(task: Awaitable[T]) -> Awaitable[T]:
|
||||||
@ -250,7 +251,7 @@ async def maybe_call(
|
|||||||
await call(msg, expected_type)
|
await call(msg, expected_type)
|
||||||
|
|
||||||
|
|
||||||
def get_context() -> Context:
|
def get_context() -> CodecContext:
|
||||||
"""Get the current session context.
|
"""Get the current session context.
|
||||||
|
|
||||||
Can be needed in case the caller needs raw read and raw write capabilities, which
|
Can be needed in case the caller needs raw read and raw write capabilities, which
|
||||||
@ -264,7 +265,7 @@ def get_context() -> Context:
|
|||||||
return CURRENT_CONTEXT
|
return CURRENT_CONTEXT
|
||||||
|
|
||||||
|
|
||||||
def with_context(ctx: Context, workflow: loop.Task) -> Generator:
|
def with_context(ctx: CodecContext, workflow: loop.Task) -> Generator:
|
||||||
"""Run a workflow in a particular context.
|
"""Run a workflow in a particular context.
|
||||||
|
|
||||||
Stores the context in a closure and installs it into the global variable every time
|
Stores the context in a closure and installs it into the global variable every time
|
||||||
|
195
core/src/trezor/wire/message_handler.py
Normal file
195
core/src/trezor/wire/message_handler.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
from micropython import const
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from storage.cache_common import InvalidSessionError
|
||||||
|
from trezor import log, loop, protobuf, utils, workflow
|
||||||
|
from trezor.enums import FailureType
|
||||||
|
from trezor.messages import Failure
|
||||||
|
from trezor.wire import context, protocol_common
|
||||||
|
from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||||
|
|
||||||
|
# Import all errors into namespace, so that `wire.Error` is available from
|
||||||
|
# other packages.
|
||||||
|
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezorio import WireInterface
|
||||||
|
from typing import Any, Callable, Container, Coroutine, TypeVar
|
||||||
|
|
||||||
|
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||||
|
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||||
|
Handler = Callable[[Msg], HandlerTask]
|
||||||
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
|
|
||||||
|
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||||
|
EXPERIMENTAL_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_protobuf_load(
|
||||||
|
buffer: bytes,
|
||||||
|
expected_type: type[LoadedMessageType],
|
||||||
|
) -> LoadedMessageType:
|
||||||
|
try:
|
||||||
|
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||||
|
if __debug__ and utils.EMULATOR:
|
||||||
|
log.debug(
|
||||||
|
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||||
|
)
|
||||||
|
return msg
|
||||||
|
except Exception as e:
|
||||||
|
if __debug__:
|
||||||
|
log.exception(__name__, e)
|
||||||
|
if e.args:
|
||||||
|
raise DataError("Failed to decode message: " + " ".join(e.args))
|
||||||
|
else:
|
||||||
|
raise DataError("Failed to decode message")
|
||||||
|
|
||||||
|
|
||||||
|
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||||
|
|
||||||
|
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||||
|
|
||||||
|
if __debug__:
|
||||||
|
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
|
||||||
|
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_single_message(
|
||||||
|
ctx: context.CodecContext, msg: protocol_common.MessageWithId, use_workflow: bool
|
||||||
|
) -> protocol_common.MessageWithId | None:
|
||||||
|
"""Handle a message that was loaded from USB by the caller.
|
||||||
|
|
||||||
|
Find the appropriate handler, run it and write its result on the wire. In case
|
||||||
|
a problem is encountered at any point, write the appropriate error on the wire.
|
||||||
|
|
||||||
|
If the workflow finished normally or with an error, the return value is None.
|
||||||
|
|
||||||
|
If an unexpected message had arrived on the wire while the workflow was processing,
|
||||||
|
the workflow is shut down with an `UnexpectedMessage` exception. This is not
|
||||||
|
considered an "error condition" to return over the wire -- instead the message
|
||||||
|
is processed as if starting a new workflow.
|
||||||
|
In such case, the `UnexpectedMessage` is caught and the message is returned
|
||||||
|
to the caller. It will then be processed in the next iteration of the message loop.
|
||||||
|
"""
|
||||||
|
if __debug__:
|
||||||
|
try:
|
||||||
|
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||||
|
except Exception:
|
||||||
|
msg_type = f"{msg.type} - unknown message type"
|
||||||
|
if ctx.channel_id is not None:
|
||||||
|
sid = int.from_bytes(ctx.channel_id, "big")
|
||||||
|
else:
|
||||||
|
sid = -1
|
||||||
|
log.debug(
|
||||||
|
__name__,
|
||||||
|
"%s:%x receive: <%s>",
|
||||||
|
ctx.iface.iface_num(),
|
||||||
|
sid,
|
||||||
|
msg_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
res_msg: protobuf.MessageType | None = None
|
||||||
|
|
||||||
|
# We need to find a handler for this message type. Should not raise.
|
||||||
|
handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none
|
||||||
|
|
||||||
|
if handler is None:
|
||||||
|
# If no handler is found, we can skip decoding and directly
|
||||||
|
# respond with failure.
|
||||||
|
await ctx.write(unexpected_message())
|
||||||
|
return None
|
||||||
|
|
||||||
|
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||||
|
workflow.autolock_interrupts_workflow = False
|
||||||
|
|
||||||
|
# Here we make sure we always respond with a Failure response
|
||||||
|
# in case of any errors.
|
||||||
|
try:
|
||||||
|
# Find a protobuf.MessageType subclass that describes this
|
||||||
|
# message. Raises if the type is not found.
|
||||||
|
req_type = protobuf.type_for_wire(msg.type)
|
||||||
|
|
||||||
|
# Try to decode the message according to schema from
|
||||||
|
# `req_type`. Raises if the message is malformed.
|
||||||
|
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||||
|
|
||||||
|
# Create the handler task.
|
||||||
|
task = handler(req_msg)
|
||||||
|
|
||||||
|
# Run the workflow task. Workflow can do more on-the-wire
|
||||||
|
# communication inside, but it should eventually return a
|
||||||
|
# response message, or raise an exception (a rather common
|
||||||
|
# thing to do). Exceptions are handled in the code below.
|
||||||
|
if use_workflow:
|
||||||
|
# Spawn a workflow around the task. This ensures that concurrent
|
||||||
|
# workflows are shut down.
|
||||||
|
res_msg = await workflow.spawn(context.with_context(ctx, task))
|
||||||
|
else:
|
||||||
|
# For debug messages, ignore workflow processing and just await
|
||||||
|
# results of the handler.
|
||||||
|
res_msg = await task
|
||||||
|
|
||||||
|
except context.UnexpectedMessage as exc:
|
||||||
|
# Workflow was trying to read a message from the wire, and
|
||||||
|
# something unexpected came in. See Context.read() for
|
||||||
|
# example, which expects some particular message and raises
|
||||||
|
# UnexpectedMessage if another one comes in.
|
||||||
|
# In order not to lose the message, we return it to the caller.
|
||||||
|
# TODO:
|
||||||
|
# We might handle only the few common cases here, like
|
||||||
|
# Initialize and Cancel.
|
||||||
|
return exc.msg
|
||||||
|
|
||||||
|
except BaseException as exc:
|
||||||
|
# Either:
|
||||||
|
# - the message had a type that has a registered handler, but does not have
|
||||||
|
# a protobuf class
|
||||||
|
# - the message was not valid protobuf
|
||||||
|
# - workflow raised some kind of an exception while running
|
||||||
|
# - something canceled the workflow from the outside
|
||||||
|
if __debug__:
|
||||||
|
if isinstance(exc, ActionCancelled):
|
||||||
|
log.debug(__name__, "cancelled: %s", exc.message)
|
||||||
|
elif isinstance(exc, loop.TaskClosed):
|
||||||
|
log.debug(__name__, "cancelled: loop task was closed")
|
||||||
|
else:
|
||||||
|
log.exception(__name__, exc)
|
||||||
|
res_msg = failure(exc)
|
||||||
|
|
||||||
|
if res_msg is not None:
|
||||||
|
# perform the write outside the big try-except block, so that usb write
|
||||||
|
# problem bubbles up
|
||||||
|
await ctx.write(res_msg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
|
||||||
|
"""Placeholder handler lookup before a proper one is registered."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
find_handler = _find_handler_placeholder
|
||||||
|
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||||
|
|
||||||
|
|
||||||
|
def failure(exc: BaseException) -> Failure:
|
||||||
|
if isinstance(exc, Error):
|
||||||
|
return Failure(code=exc.code, message=exc.message)
|
||||||
|
elif isinstance(exc, loop.TaskClosed):
|
||||||
|
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||||
|
elif isinstance(exc, InvalidSessionError):
|
||||||
|
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||||
|
else:
|
||||||
|
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
||||||
|
# change the `if __debug__` to `if True` to get the full error message.
|
||||||
|
if __debug__:
|
||||||
|
message = str(exc)
|
||||||
|
else:
|
||||||
|
message = "Firmware error"
|
||||||
|
return Failure(code=FailureType.FirmwareError, message=message)
|
||||||
|
|
||||||
|
|
||||||
|
def unexpected_message() -> Failure:
|
||||||
|
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
@ -1,3 +1,6 @@
|
|||||||
|
from trezor import protobuf
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -24,3 +27,11 @@ class MessageWithId(Message):
|
|||||||
|
|
||||||
class WireError(Exception):
|
class WireError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Context:
|
||||||
|
def __init__(self, iface, channel_id) -> None:
|
||||||
|
self.iface = iface
|
||||||
|
self.channel_id = channel_id
|
||||||
|
|
||||||
|
async def write(self, msg: protobuf.MessageType) -> None: ...
|
||||||
|
106
core/src/trezor/wire/thp/channel_context.py
Normal file
106
core/src/trezor/wire/thp/channel_context.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
import ustruct
|
||||||
|
from micropython import const
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from storage.cache_thp import SessionThpCache
|
||||||
|
from trezor import loop, protobuf, utils
|
||||||
|
|
||||||
|
from ..protocol_common import Context
|
||||||
|
from . import thp_session
|
||||||
|
from .thp_messages import CONTINUATION_PACKET, ENCRYPTED_TRANSPORT
|
||||||
|
|
||||||
|
# from .thp_session import SessionState, ThpError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezorio import WireInterface
|
||||||
|
|
||||||
|
_INIT_DATA_OFFSET = const(5)
|
||||||
|
_CONT_DATA_OFFSET = const(3)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelContext(Context):
|
||||||
|
def __init__(
|
||||||
|
self, iface: WireInterface, channel_id: int, session_data: SessionThpCache
|
||||||
|
) -> None:
|
||||||
|
super().__init__(iface, channel_id)
|
||||||
|
self.session_data = session_data
|
||||||
|
self.buffer: utils.BufferType
|
||||||
|
self.waiting_for_ack_timeout: loop.Task | None
|
||||||
|
self.is_cont_packet_expected: bool = False
|
||||||
|
self.expected_payload_length: int = 0
|
||||||
|
self.bytes_read = 0
|
||||||
|
|
||||||
|
# ACCESS TO SESSION_DATA
|
||||||
|
|
||||||
|
def get_management_session_state(self):
|
||||||
|
return thp_session.get_state(self.session_data)
|
||||||
|
|
||||||
|
# CALLED BY THP_MAIN_LOOP
|
||||||
|
|
||||||
|
async def receive_packet(self, packet: utils.BufferType):
|
||||||
|
ctrl_byte = packet[0]
|
||||||
|
if _is_ctrl_byte_continuation(ctrl_byte):
|
||||||
|
await self._handle_cont_packet(packet)
|
||||||
|
else:
|
||||||
|
await self._handle_init_packet(packet)
|
||||||
|
|
||||||
|
async def _handle_init_packet(self, packet):
|
||||||
|
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
|
||||||
|
packet_payload = packet[5:]
|
||||||
|
|
||||||
|
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
||||||
|
packet_payload = self._decode(packet_payload)
|
||||||
|
|
||||||
|
# session_id = packet_payload[0] # TODO handle handshake differently
|
||||||
|
self.expected_payload_length = payload_length
|
||||||
|
self.bytes_read = 0
|
||||||
|
|
||||||
|
await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET)
|
||||||
|
# TODO Set/Provide different buffer for management session
|
||||||
|
|
||||||
|
if self.expected_payload_length == self.bytes_read:
|
||||||
|
self._finish_message()
|
||||||
|
else:
|
||||||
|
self.is_cont_packet_expected = True
|
||||||
|
|
||||||
|
async def _handle_cont_packet(self, packet):
|
||||||
|
if not self.is_cont_packet_expected:
|
||||||
|
return # Continuation packet is not expected, ignoring
|
||||||
|
await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET)
|
||||||
|
|
||||||
|
def _decode(self, payload) -> bytes:
|
||||||
|
return payload # TODO add decryption process
|
||||||
|
|
||||||
|
async def _buffer_packet_data(
|
||||||
|
self, payload_buffer, packet: utils.BufferType, offset
|
||||||
|
):
|
||||||
|
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||||
|
|
||||||
|
def _finish_message(self):
|
||||||
|
# TODO Provide loaded message to SessionContext or handle it with this ChannelContext
|
||||||
|
self.bytes_read = 0
|
||||||
|
self.expected_payload_length = 0
|
||||||
|
self.is_cont_packet_expected = False
|
||||||
|
|
||||||
|
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||||
|
|
||||||
|
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||||
|
pass
|
||||||
|
# TODO protocol.write(self.iface, self.channel_id, session_id, msg)
|
||||||
|
|
||||||
|
def create_new_session(
|
||||||
|
self,
|
||||||
|
passphrase="",
|
||||||
|
) -> None: # TODO change it to output session data
|
||||||
|
pass
|
||||||
|
# TODO check, wheter a session with this passphrase already exists
|
||||||
|
# if not, create a new session with this passphrase
|
||||||
|
# if yes, what TODO TODO ???
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & 0x80 == CONTINUATION_PACKET
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
|
||||||
|
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
|
14
core/src/trezor/wire/thp/session_context.py
Normal file
14
core/src/trezor/wire/thp/session_context.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from trezor import protobuf
|
||||||
|
|
||||||
|
from ..context import Context
|
||||||
|
from .channel_context import ChannelContext
|
||||||
|
|
||||||
|
|
||||||
|
class SessionContext(Context):
|
||||||
|
def __init__(self, channel_context: ChannelContext, session_id: int) -> None:
|
||||||
|
super().__init__(channel_context.iface, channel_context.channel_id)
|
||||||
|
self.channel_context = channel_context
|
||||||
|
self.session_id = session_id
|
||||||
|
|
||||||
|
async def write(self, msg: protobuf.MessageType) -> None:
|
||||||
|
return await self.channel_context.write(msg, self.session_id)
|
@ -5,6 +5,7 @@ from storage.cache_thp import BROADCAST_CHANNEL_ID
|
|||||||
from ..protocol_common import Message
|
from ..protocol_common import Message
|
||||||
|
|
||||||
CONTINUATION_PACKET = 0x80
|
CONTINUATION_PACKET = 0x80
|
||||||
|
ENCRYPTED_TRANSPORT = 0x02
|
||||||
_ERROR = 0x41
|
_ERROR = 0x41
|
||||||
_CHANNEL_ALLOCATION_RES = 0x40
|
_CHANNEL_ALLOCATION_RES = 0x40
|
||||||
|
|
||||||
|
@ -8,8 +8,14 @@ from trezor import io, loop, utils
|
|||||||
from .protocol_common import MessageWithId
|
from .protocol_common import MessageWithId
|
||||||
from .thp import ack_handler, checksum, thp_messages
|
from .thp import ack_handler, checksum, thp_messages
|
||||||
from .thp import thp_session as THP
|
from .thp import thp_session as THP
|
||||||
|
from .thp.channel_context import ChannelContext
|
||||||
from .thp.checksum import CHECKSUM_LENGTH
|
from .thp.checksum import CHECKSUM_LENGTH
|
||||||
from .thp.thp_messages import CONTINUATION_PACKET, InitHeader, InterruptingInitPacket
|
from .thp.thp_messages import (
|
||||||
|
CONTINUATION_PACKET,
|
||||||
|
ENCRYPTED_TRANSPORT,
|
||||||
|
InitHeader,
|
||||||
|
InterruptingInitPacket,
|
||||||
|
)
|
||||||
from .thp.thp_session import SessionState, ThpError
|
from .thp.thp_session import SessionState, ThpError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -21,12 +27,14 @@ _CHANNEL_ALLOCATION_REQ = 0x40
|
|||||||
_ACK_MESSAGE = 0x20
|
_ACK_MESSAGE = 0x20
|
||||||
_HANDSHAKE_INIT = 0x00
|
_HANDSHAKE_INIT = 0x00
|
||||||
_PLAINTEXT = 0x01
|
_PLAINTEXT = 0x01
|
||||||
ENCRYPTED_TRANSPORT = 0x02
|
|
||||||
|
|
||||||
_REPORT_LENGTH = const(64)
|
_REPORT_LENGTH = const(64)
|
||||||
_REPORT_INIT_DATA_OFFSET = const(5)
|
_REPORT_INIT_DATA_OFFSET = const(5)
|
||||||
_REPORT_CONT_DATA_OFFSET = const(3)
|
_REPORT_CONT_DATA_OFFSET = const(3)
|
||||||
|
|
||||||
|
_BUFFER: bytearray
|
||||||
|
_BUFFER_LOCK = None
|
||||||
|
|
||||||
|
|
||||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
|
||||||
msg = await read_message_or_init_packet(iface, buffer)
|
msg = await read_message_or_init_packet(iface, buffer)
|
||||||
@ -38,6 +46,37 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def set_buffer(buffer):
|
||||||
|
_BUFFER = buffer
|
||||||
|
print(_BUFFER) # TODO remove
|
||||||
|
|
||||||
|
|
||||||
|
async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||||
|
|
||||||
|
CHANNELS: dict[int, ChannelContext] = {}
|
||||||
|
# TODO load cached channels/sessions
|
||||||
|
|
||||||
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
packet = await read
|
||||||
|
ctrl_byte, cid = ustruct.unpack(">BH", packet)
|
||||||
|
|
||||||
|
if cid == BROADCAST_CHANNEL_ID:
|
||||||
|
await _handle_broadcast(iface, ctrl_byte, packet)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cid in CHANNELS:
|
||||||
|
channel = CHANNELS[cid]
|
||||||
|
if channel is None:
|
||||||
|
raise ThpError("Invalid state of a channel")
|
||||||
|
if channel.get_management_session_state != SessionState.UNALLOCATED:
|
||||||
|
await channel.receive_packet(packet)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await _handle_unallocated(iface, cid)
|
||||||
|
|
||||||
|
|
||||||
async def read_message_or_init_packet(
|
async def read_message_or_init_packet(
|
||||||
iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None
|
iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None
|
||||||
) -> MessageWithId | InterruptingInitPacket:
|
) -> MessageWithId | InterruptingInitPacket:
|
||||||
@ -50,10 +89,10 @@ async def read_message_or_init_packet(
|
|||||||
raise ThpError("Reading failed unexpectedly, report is None.")
|
raise ThpError("Reading failed unexpectedly, report is None.")
|
||||||
|
|
||||||
# Channel multiplexing
|
# Channel multiplexing
|
||||||
ctrl_byte, cid, payload_length = ustruct.unpack(">BHH", report)
|
ctrl_byte, cid = ustruct.unpack(">BH", report)
|
||||||
|
|
||||||
if cid == BROADCAST_CHANNEL_ID:
|
if cid == BROADCAST_CHANNEL_ID:
|
||||||
await _handle_broadcast(iface, ctrl_byte, report) # TODO await
|
await _handle_broadcast(iface, ctrl_byte, report)
|
||||||
report = None
|
report = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -64,7 +103,7 @@ async def read_message_or_init_packet(
|
|||||||
# continuation packet is not expected - ignore
|
# continuation packet is not expected - ignore
|
||||||
report = None
|
report = None
|
||||||
continue
|
continue
|
||||||
|
payload_length = ustruct.unpack(">H", report[3:])[0]
|
||||||
payload = _get_buffer_for_payload(payload_length, buffer)
|
payload = _get_buffer_for_payload(payload_length, buffer)
|
||||||
header = InitHeader(ctrl_byte, cid, payload_length)
|
header = InitHeader(ctrl_byte, cid, payload_length)
|
||||||
|
|
||||||
@ -255,15 +294,15 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_broadcast(
|
async def _handle_broadcast(
|
||||||
iface: WireInterface, ctrl_byte, report
|
iface: WireInterface, ctrl_byte, packet
|
||||||
) -> MessageWithId | None:
|
) -> MessageWithId | None:
|
||||||
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
|
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
|
||||||
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
||||||
|
|
||||||
length, nonce = ustruct.unpack(">H8s", report[3:])
|
length, nonce = ustruct.unpack(">H8s", packet[3:])
|
||||||
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
|
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
|
||||||
|
|
||||||
payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
|
payload = _get_buffer_for_payload(length, packet[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
|
||||||
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||||
raise ThpError("Checksum is not valid")
|
raise ThpError("Checksum is not valid")
|
||||||
|
|
||||||
@ -274,8 +313,8 @@ async def _handle_broadcast(
|
|||||||
response_header = InitHeader.get_channel_allocation_response_header(
|
response_header = InitHeader.get_channel_allocation_response_header(
|
||||||
len(response_data) + CHECKSUM_LENGTH,
|
len(response_data) + CHECKSUM_LENGTH,
|
||||||
)
|
)
|
||||||
|
|
||||||
chksum = checksum.compute(response_header.to_bytes() + response_data)
|
chksum = checksum.compute(response_header.to_bytes() + response_data)
|
||||||
|
|
||||||
await write_to_wire(iface, response_header, response_data + chksum)
|
await write_to_wire(iface, response_header, response_data + chksum)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user