diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 9b0efeb1d..cf40d9672 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -177,20 +177,13 @@ def get_features() -> Features: # handle_Initialize should not be used with THP to start a new session -async def handle_Initialize( - msg: Initialize, message_session_id: bytearray | None = None -) -> Features: - if message_session_id is None and utils.USE_THP: +async def handle_Initialize(msg: Initialize) -> Features: + if utils.USE_THP: raise ValueError("With THP enabled, a session id must be provided in args") - if utils.USE_THP: - session_id = storage_thp_cache.start_existing_session(msg.session_id) - else: - session_id = storage_cache.start_session(msg.session_id) + session_id = storage_cache.start_session(msg.session_id) if not utils.BITCOIN_ONLY: - # TODO this block should be changed in THP - derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO) have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED) @@ -212,7 +205,7 @@ async def handle_Initialize( ) features = get_features() - features.session_id = session_id # not important in THP + features.session_id = session_id return features diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index d9d39bd23..63f505dc0 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -56,7 +56,7 @@ class SessionThpCache(DataCache): # TODO implement, this is just copied Session def clear(self) -> None: super().clear() - self.state = 0 # Set state to UNALLOCATED + self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED self.last_usage = 0 self.session_id[:] = b"" @@ -175,6 +175,7 @@ def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache: _session_usage_counter += 1 _SESSIONS[new_auth_session_index].last_usage = _session_usage_counter + return _SESSIONS[new_auth_session_index] def get_least_recently_used_authetnicated_session_index() -> int: @@ -216,7 +217,7 @@ def start_session(session_id: bytes | None) -> bytes: # TODO incomplete return new_session_id -def start_existing_session(session_id: bytearray) -> bytes: +def start_existing_session(session_id: bytes) -> bytes: if session_id is None: raise ValueError("session_id cannot be None") if get_active_session_id() == session_id: diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 67e0f8fcc..2f5edb05f 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -47,7 +47,6 @@ if TYPE_CHECKING: Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] Handler = Callable[[Msg], HandlerTask] - HandlerWithSessionId = Callable[[Msg, bytes | None], HandlerTask] LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) @@ -57,7 +56,9 @@ EXPERIMENTAL_ENABLED = False def setup(iface: WireInterface, is_debug_session: bool = False) -> None: """Initialize the wire stack on passed USB interface.""" - loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session)) + loop.schedule( + handle_session(iface, codec_v1.SESSION_ID.to_bytes(4, "big"), is_debug_session) + ) def wrap_protobuf_load( @@ -145,11 +146,7 @@ async def _handle_single_message( req_msg = wrap_protobuf_load(msg.data, req_type) # Create the handler task. - if msg.type is MT.Initialize: - # Special case for handle_initialize to have access to the verified session_id - task = handler(req_msg, ctx.session_id) - else: - task = handler(req_msg) + task = handler(req_msg) # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a @@ -268,9 +265,7 @@ async def handle_session( log.exception(__name__, exc) -def _find_handler_placeholder( - iface: WireInterface, msg_type: int -) -> Handler | HandlerWithSessionId | None: +def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None: """Placeholder handler lookup before a proper one is registered.""" return None diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index e5a40bab7..bd4be6052 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING from trezor import log, loop, protobuf -from .protocol import WireProtocol +import trezor.wire.protocol from .protocol_common import Message if TYPE_CHECKING: @@ -69,11 +69,11 @@ class Context: ) -> None: self.iface = iface self.buffer = buffer - self.session_id: session_id + self.session_id = session_id def read_from_wire(self) -> Awaitable[Message]: """Read a whole message from the wire without parsing it.""" - return WireProtocol.read_message(self, self.iface, self.buffer) + return protocol.read_message(self, self.iface, self.buffer) if TYPE_CHECKING: @@ -160,7 +160,7 @@ class Context: msg_size = protobuf.encode(buffer, msg) - await WireProtocol.write_message( + await protocol.write_message( self.iface, Message( message_type=msg.MESSAGE_WIRE_TYPE, diff --git a/core/src/trezor/wire/protocol.py b/core/src/trezor/wire/protocol.py index aa728c3f0..fc0e565cb 100644 --- a/core/src/trezor/wire/protocol.py +++ b/core/src/trezor/wire/protocol.py @@ -7,17 +7,15 @@ if TYPE_CHECKING: from trezorio import WireInterface -class WireProtocol: - async def read_message( - self, iface: WireInterface, buffer: utils.BufferType - ) -> Message: - if utils.USE_THP: - return await thp_v1.read_message(iface, buffer) - return await codec_v1.read_message(iface, buffer) +async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: + if utils.USE_THP: + return await thp_v1.read_message(iface, buffer) + return await codec_v1.read_message(iface, buffer) - async def write_message(self, iface: WireInterface, message: Message) -> None: - if utils.USE_THP: - await thp_v1.write_to_wire(iface, message) # TODO incomplete - return - await codec_v1.write_message(iface, message.type, message.data) + +async def write_message(iface: WireInterface, message: Message) -> None: + if utils.USE_THP: + await thp_v1.write_message(iface, message) return + await codec_v1.write_message(iface, message.type, message.data) + return diff --git a/core/src/trezor/wire/thp_session.py b/core/src/trezor/wire/thp_session.py index 3b68a5a3d..3fd38ec61 100644 --- a/core/src/trezor/wire/thp_session.py +++ b/core/src/trezor/wire/thp_session.py @@ -105,7 +105,7 @@ def is_active_session(session: SessionThpCache): def set_session_state(session: SessionThpCache, new_state: SessionState): - session.state = new_state.to_bytes(1, "big") + session.state = bytearray(new_state.to_bytes(1, "big")) def _get_id(iface: WireInterface, cid: int) -> bytes: diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 1c9e75d94..7085330b6 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -86,9 +86,8 @@ async def read_message_or_init_packet( report = firstReport while True: # Wait for an initial report - if firstReport is None: + if report is None: report = await _get_loop_wait_read(iface) - if report is None: raise ThpError("Reading failed unexpectedly, report is None.") @@ -96,7 +95,7 @@ async def read_message_or_init_packet( ctrl_byte, cid = ustruct.unpack(">BH", report) if cid == BROADCAST_CHANNEL_ID: - await _handle_broadcast(iface, ctrl_byte, report) + await _handle_broadcast(iface, ctrl_byte, report) # TODO await report = None continue @@ -258,7 +257,7 @@ async def write_message( async def write_to_wire( iface: WireInterface, header: InitHeader, payload: bytes ) -> None: - write = loop.wait(iface.iface_num() | io.POLL_WRITE) + loop_write = loop.wait(iface.iface_num() | io.POLL_WRITE) payload_length = len(payload) @@ -268,7 +267,7 @@ async def write_to_wire( # write initial report nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0) - await _write_report(write, iface, report) + await _write_report(loop_write, iface, report) # if we have more data to write, use continuation reports for it if nwritten < payload_length: @@ -276,7 +275,7 @@ async def write_to_wire( while nwritten < payload_length: nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten) - await _write_report(write, iface, report) + await _write_report(loop_write, iface, report) async def _write_report(write, iface: WireInterface, report: bytearray) -> None: @@ -287,7 +286,7 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None: return -async def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None: +async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message | None: if ctrl_byte != _CHANNEL_ALLOCATION_REQ: raise ThpError("Unexpected ctrl_byte in broadcast channel packet") length, nonce, checksum = ustruct.unpack(">H8s4s", report[3:])