diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 0e76bb3f27..287f3b3a43 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -60,17 +60,18 @@ class Channel: """ def __init__(self, channel_cache: ChannelCache) -> None: - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "channel initialization") # Channel properties + self.channel_id: bytes = channel_cache.channel_id + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("channel initialization") self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface) self.channel_cache: ChannelCache = channel_cache - self.channel_id: bytes = channel_cache.channel_id # Shared variables self.buffer: utils.BufferType = bytearray(self.iface.TX_PACKET_LEN) self.fallback_decrypt: bool = False + self.fallback_session_id: int | None = None self.bytes_read: int = 0 self.expected_payload_length: int = 0 self.is_cont_packet_expected: bool = False @@ -139,7 +140,9 @@ class Channel: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("receive packet") - self._handle_received_packet(packet) + task = self._handle_received_packet(packet) + if task is not None: + return task if self.expected_payload_length == 0: # Reading failed TODO from trezor.wire.thp import ThpErrorType @@ -148,13 +151,24 @@ class Channel: try: buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) - except WireBufferError: - pass # TODO ?? - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - try: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("self.buffer: ", get_bytes_as_str(buffer)) - except Exception: - pass # TODO handle nicer - happens in fallback_decrypt + except WireBufferError: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log( + "getting read buffer failed - ", str(WireBufferError.__name__) + ) + pass # TODO ?? + if self.fallback_decrypt and self.expected_payload_length == self.bytes_read: + self._finish_fallback() + from trezor.messages import Failure + from trezor.enums import FailureType + + return self.write( + Failure(code=FailureType.DeviceIsBusy, message="FALLBACK!"), + session_id=self.fallback_session_id or 0, + fallback=True, + ) if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: self._finish_message() @@ -166,21 +180,30 @@ class Channel: return received_message_handler.handle_received_message(self, buffer) elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read: self.is_cont_packet_expected = True + self._log( + "CONT EXPECTED - read/expected:", + str(self.bytes_read) + + "/" + + str(self.expected_payload_length + INIT_HEADER_LENGTH), + ) else: raise ThpError( "Read more bytes than is the expected length of the message!" ) return None - def _handle_received_packet(self, packet: utils.BufferType) -> None: + def _handle_received_packet( + self, packet: utils.BufferType + ) -> Awaitable[None] | None: ctrl_byte = packet[0] if control_byte.is_continuation(ctrl_byte): self._handle_cont_packet(packet) - return - self._handle_init_packet(packet) + return None + return self._handle_init_packet(packet) def _handle_init_packet(self, packet: utils.BufferType) -> None: self.fallback_decrypt = False + self.fallback_session_id = None self.bytes_read = 0 self.expected_payload_length = 0 @@ -204,11 +227,15 @@ class Channel: try: buffer = memory_manager.get_new_read_buffer(cid, length) except WireBufferError: + self.fallback_decrypt = True # TODO handle not encrypted/(short??), eg. ACK - self.fallback_decrypt = True - try: + if not self._can_fallback(): + raise Exception( + "Channel is in a state that does not support fallback." + ) + self._log("Started fallback read") self._prepare_fallback() except Exception: self.fallback_decrypt = False @@ -220,7 +247,7 @@ class Channel: log.debug( __name__, "FAILED TO FALLBACK: %s", hexlify(packet).decode() ) - return + return None to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length) buf = memoryview(self.buffer)[:to_read_len] @@ -229,17 +256,23 @@ class Channel: # CRC CHECK self._handle_fallback_crc(buf) + # Handle ACK + if control_byte.is_ack(packet[0]): + ack_bit = (packet[0] & 0x08) >> 3 + return received_message_handler._handle_ack(self, ack_bit) + # TAG CHECK self._handle_fallback_decryption(buf) self.bytes_read += to_read_len - return + return None if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("handle_init_packet - payload len: ", str(payload_length)) self._log("handle_init_packet - buffer len: ", str(len(buffer))) self._buffer_packet_data(buffer, packet, 0) + return None def _handle_fallback_crc(self, buf: memoryview) -> None: assert self.temp_crc is not None @@ -303,6 +336,8 @@ class Channel: utils.memcpy(self.temp_tag, offset, noise_tag, 0) else: raise Exception("Buffer (+bytes_read) should not be bigger than payload") + if self.fallback_session_id is None: + self.fallback_session_id = buf[0] def _handle_cont_packet(self, packet: utils.BufferType) -> None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: @@ -347,6 +382,7 @@ class Channel: def _finish_fallback(self) -> None: self.fallback_decrypt = False self.busy_decoder = None + self._log("Finish fallback") def _decrypt_single_packet_payload( self, payload: utils.BufferType @@ -419,6 +455,7 @@ class Channel: msg: protobuf.MessageType, session_id: int = 0, force: bool = False, + fallback: bool = False, ) -> None: if __debug__ and utils.EMULATOR: self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg)) @@ -428,7 +465,10 @@ class Channel: payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size length = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + INIT_HEADER_LENGTH try: - buffer = memory_manager.get_new_write_buffer(cid, length) + if fallback: + buffer = self.buffer + else: + buffer = memory_manager.get_new_write_buffer(cid, length) noise_payload_len = memory_manager.encode_into_buffer( buffer, msg, session_id ) @@ -448,7 +488,9 @@ class Channel: session_id, ) self.set_channel_state(ChannelState.INVALIDATED) - task = self._write_and_encrypt(noise_payload_len, force) + task = self._write_and_encrypt( + noise_payload_len=noise_payload_len, force=force, fallback=fallback + ) if task is not None: await task @@ -465,9 +507,15 @@ class Channel: ) def _write_and_encrypt( - self, noise_payload_len: int, force: bool = False + self, + noise_payload_len: int, + force: bool = False, + fallback: bool = False, ) -> Awaitable[None] | None: - buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int()) + if fallback: + buffer = self.buffer + else: + buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int()) # if buffer is WireBufferError: # pass # TODO handle deviceBUSY @@ -478,6 +526,18 @@ class Channel: self.write_task_spawn.close() # UPS TODO might break something print("\nCLOSED\n") self._prepare_write() + if fallback: + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log( + "Writing FALLBACK message (written only once without async or retransmission)." + ) + + return self._write_encrypted_payload_loop( + ctrl_byte=ENCRYPTED, + payload=memoryview(buffer[:payload_length]), + only_once=True, + ) + if force: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("Writing FORCE message (without async or retransmission).") @@ -497,7 +557,7 @@ class Channel: ABP.set_sending_allowed(self.channel_cache, False) async def _write_encrypted_payload_loop( - self, ctrl_byte: int, payload: bytes + self, ctrl_byte: int, payload: bytes, only_once: bool = False ) -> None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("write_encrypted_payload_loop") @@ -507,7 +567,10 @@ class Channel: ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit) header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len) self.transmission_loop = TransmissionLoop(self, header, payload) - await self.transmission_loop.start() + if only_once: + await self.transmission_loop.start(max_retransmission_count=1) + else: + await self.transmission_loop.start() ABP.set_send_seq_bit_to_opposite(self.channel_cache) @@ -516,7 +579,7 @@ class Channel: if self._can_clear_loop(): if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("clearing loop from channel") - + pass loop.clear() def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: @@ -549,6 +612,14 @@ class Channel: not workflow.tasks ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT + def _can_fallback(self) -> bool: + state = self.get_channel_state() + return state not in [ + ChannelState.TH1, + ChannelState.TH2, + ChannelState.UNALLOCATED, + ] + if __debug__: def _log(self, text_1: str, text_2: str = "") -> None: