diff --git a/core/src/trezor/wire/thp/channel_manager.py b/core/src/trezor/wire/thp/channel_manager.py index 75de9485f9..395d86e183 100644 --- a/core/src/trezor/wire/thp/channel_manager.py +++ b/core/src/trezor/wire/thp/channel_manager.py @@ -8,6 +8,10 @@ from .channel import Channel if TYPE_CHECKING: from trezorio import WireInterface +if __debug__: + from trezor import log + +CHANNELS_LOADED: bool = False def create_new_channel(iface: WireInterface) -> Channel: """ @@ -19,12 +23,21 @@ def create_new_channel(iface: WireInterface) -> Channel: return channel -def load_cached_channels() -> dict[int, Channel]: +def load_cached_channels(channels_dict: dict[int, Channel]) -> None: """ Returns all allocated channels from cache. """ - channels: dict[int, Channel] = {} + global CHANNELS_LOADED + + if CHANNELS_LOADED: + if __debug__: + log.debug(__name__, "Channels already loaded, process skipped.") + return + cached_channels = cache_thp.get_all_allocated_channels() for channel in cached_channels: - channels[int.from_bytes(channel.channel_id, "big")] = Channel(channel) - return channels + channel_id = int.from_bytes(channel.channel_id, "big") + channels_dict[channel_id] = Channel(channel) + if __debug__: + log.debug(__name__, "Channels loaded from cache.") + CHANNELS_LOADED = True diff --git a/core/src/trezor/wire/thp/thp_main.py b/core/src/trezor/wire/thp/thp_main.py index 2e16a0f8b3..0a31162d0e 100644 --- a/core/src/trezor/wire/thp/thp_main.py +++ b/core/src/trezor/wire/thp/thp_main.py @@ -35,37 +35,41 @@ _CHANNELS: dict[int, Channel] = {} async def thp_main_loop(iface: WireInterface) -> None: global _CHANNELS - _CHANNELS = channel_manager.load_cached_channels() + channel_manager.load_cached_channels(_CHANNELS) read = loop.wait(iface.iface_num() | io.POLL_READ) packet = bytearray(iface.RX_PACKET_LEN) - while True: - try: - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "thp_main_loop") - packet_len = await read - assert packet_len == len(packet) - iface.read(packet, 0) + try: + while True: + try: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug( + __name__, f"thp_main_loop from iface: {iface.iface_num()}" + ) + packet_len = await read + assert packet_len == len(packet) + iface.read(packet, 0) - if _get_ctrl_byte(packet) == CODEC_V1: - await _handle_codec_v1(iface, packet) - continue + if _get_ctrl_byte(packet) == CODEC_V1: + await _handle_codec_v1(iface, packet) + continue - cid = ustruct.unpack(">BH", packet)[1] + cid = ustruct.unpack(">BH", packet)[1] - if cid == BROADCAST_CHANNEL_ID: - await _handle_broadcast(iface, packet) - continue + if cid == BROADCAST_CHANNEL_ID: + await _handle_broadcast(iface, packet) + continue - if cid in _CHANNELS: - await _handle_allocated(iface, cid, packet) - else: - await _handle_unallocated(iface, cid, packet) - - except ThpError as e: - if __debug__: - log.exception(__name__, e) + if cid in _CHANNELS: + await _handle_allocated(iface, cid, packet) + else: + await _handle_unallocated(iface, cid, packet) + except ThpError as e: + if __debug__: + log.exception(__name__, e) + finally: + channel_manager.CHANNELS_LOADED = False async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None: # If the received packet is not an initial codec_v1 packet, do not send error message