diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 80757fcde..b245ff5c9 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -49,7 +49,7 @@ class Channel(Context): super().__init__(iface, channel_cache.channel_id) self.channel_cache = channel_cache self.buffer: utils.BufferType - self.waiting_for_ack_timeout: loop.spawn | None + self.waiting_for_ack_timeout: loop.spawn | None = None self.is_cont_packet_expected: bool = False self.expected_payload_length: int = 0 self.bytes_read = 0 @@ -211,6 +211,10 @@ class Channel(Context): if state is ChannelState.TH2: await self._handle_state_TH2(message_length, sync_bit) return + if is_channel_state_pairing(state): + await self._handle_pairing(message_length) + return + raise ThpError("Unimplemented channel state") async def _handle_state_TH1( self, payload_length: int, message_length: int, sync_bit: int @@ -279,6 +283,9 @@ class Channel(Context): ) ) + async def _handle_pairing(self, message_length: int) -> None: + pass + def _handle_channel_message(self, message_length: int, message_type: int) -> None: buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH] @@ -581,6 +588,18 @@ def _is_ctrl_byte_ack(ctrl_byte: int) -> bool: return ctrl_byte & 0xEF == ACK_MESSAGE +def is_channel_state_pairing(state: int) -> bool: + if state in ( + ChannelState.TP1, + ChannelState.TP2, + ChannelState.TP3, + ChannelState.TP4, + ChannelState.TP5, + ): + return True + return False + + def _state_to_str(state: int) -> str: if state == ChannelState.ENCRYPTED_TRANSPORT: return "state: encrypted transport"