From 108d9ec89b43bbe8a782dc835ca1f95afe13b538 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Mon, 11 Mar 2024 15:46:08 +0100 Subject: [PATCH] Lower the number of style prebuild errors 2 --- core/src/storage/cache_thp.py | 11 +++++++---- core/src/trezor/wire/protocol.py | 6 ++++-- core/src/trezor/wire/thp_session.py | 7 ++++--- core/src/trezor/wire/thp_v1.py | 15 +++++++++++---- 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 828017e2bb..d9d39bd238 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -56,6 +56,7 @@ class SessionThpCache(DataCache): # TODO implement, this is just copied Session def clear(self) -> None: super().clear() + self.state = 0 # Set state to UNALLOCATED self.last_usage = 0 self.session_id[:] = b"" @@ -93,10 +94,12 @@ _session_usage_counter = 0 cid_counter: int = 4659 -def get_active_session_id(): - if get_active_session() is None: +def get_active_session_id() -> bytearray | None: + active_session = get_active_session() + + if active_session is None: return None - return get_active_session().session_id + return active_session.session_id def get_active_session() -> SessionThpCache | None: @@ -168,7 +171,7 @@ def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache: new_auth_session_index = get_least_recently_used_authetnicated_session_index() _SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx] - _UNAUTHENTICATED_SESSIONS[unauth_session_idx] = None + _UNAUTHENTICATED_SESSIONS[unauth_session_idx].clear() _session_usage_counter += 1 _SESSIONS[new_auth_session_index].last_usage = _session_usage_counter diff --git a/core/src/trezor/wire/protocol.py b/core/src/trezor/wire/protocol.py index 95d27b4f1f..aa728c3f0c 100644 --- a/core/src/trezor/wire/protocol.py +++ b/core/src/trezor/wire/protocol.py @@ -17,5 +17,7 @@ class WireProtocol: async def write_message(self, iface: WireInterface, message: Message) -> None: if utils.USE_THP: - return thp_v1.write_to_wire(iface, message) - return codec_v1.write_message(iface, message.type, message.data) + await thp_v1.write_to_wire(iface, message) # TODO incomplete + return + await codec_v1.write_message(iface, message.type, message.data) + return diff --git a/core/src/trezor/wire/thp_session.py b/core/src/trezor/wire/thp_session.py index 863c5dd7e8..3b68a5a3d6 100644 --- a/core/src/trezor/wire/thp_session.py +++ b/core/src/trezor/wire/thp_session.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING from ubinascii import hexlify if TYPE_CHECKING: + from trezorio import WireInterface from enum import IntEnum else: IntEnum = object @@ -26,8 +27,8 @@ class SessionState(IntEnum): def create_autenticated_session(unauthenticated_session: SessionThpCache): - storage_thp_cache.start_session() # TODO something like this but for THP - raise + # storage_thp_cache.start_session() - TODO something like this but for THP + raise NotImplementedError("Secure channel is not implemented, yet.") def create_new_unauthenticated_session(iface: WireInterface, cid: int): @@ -107,7 +108,7 @@ def set_session_state(session: SessionThpCache, new_state: SessionState): session.state = new_state.to_bytes(1, "big") -def _get_id(iface: WireInterface, cid: int) -> bytearray: +def _get_id(iface: WireInterface, cid: int) -> bytes: return ustruct.pack(">HH", iface.iface_num(), cid) diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 6e609bdd41..1c9e75d94d 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -88,6 +88,10 @@ async def read_message_or_init_packet( # Wait for an initial report if firstReport is None: report = await _get_loop_wait_read(iface) + + if report is None: + raise ThpError("Reading failed unexpectedly, report is None.") + # Channel multiplexing ctrl_byte, cid = ustruct.unpack(">BH", report) @@ -104,9 +108,6 @@ async def read_message_or_init_packet( report = None continue - if report is None: - raise ThpError("Reading failed unexpectedly, report is None.") - payload_length = ustruct.unpack(">H", report[3:])[0] payload = _get_buffer_for_payload(payload_length, buffer) header = InitHeader(ctrl_byte, cid, payload_length) @@ -134,6 +135,9 @@ async def read_message_or_init_packet( report = None continue + if session is None: + raise ThpError("Invalid session!") + # Note: In the Host, the UNALLOCATED_CHANNEL error should be handled here # Synchronization process @@ -219,6 +223,9 @@ async def write_message( iface: WireInterface, message: Message, is_retransmission: bool = False ) -> None: session = THP.get_session_from_id(message.session_id) + if session is None: + raise ThpError("Invalid session") + cid = THP.get_cid(session) payload = message.type.to_bytes(2, "big") + message.data payload_length = len(payload) @@ -355,7 +362,7 @@ def _get_new_channel_id() -> int: return THP.get_next_channel_id() -def _is_checksum_valid(checksum: bytes | utils.BufferType, data: bytearray) -> bool: +def _is_checksum_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool: data_checksum = _compute_checksum_bytes(data) return checksum == data_checksum