diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index af88fa847f..2e6b78ef89 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -128,6 +128,11 @@ class Channel: if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: self._finish_message() + if self.fallback_decrypt: + # TODO Check CRC and if valid, check tag, if valid update nonces + self._finish_fallback() + # TODO self.write() failure device is busy - use channel buffer to send this failure message!! + return return received_message_handler.handle_received_message(self, buffer) elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read: self.is_cont_packet_expected = True @@ -155,9 +160,6 @@ class Channel: _, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) self.expected_payload_length = payload_length - # packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:] - # The above could be used for single packet decryption - # If the channel does not "own" the buffer lock, decrypt first packet # TODO do it only when needed! # TODO FIX: If "_decrypt_single_packet_payload" is implemented, it will (possibly) break "decrypt_buffer" and nonces incrementation. @@ -188,6 +190,9 @@ class Channel: # TAG CHECK self._handle_fallback_decryption(buf) + self.bytes_read += to_read_len + return + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("handle_init_packet - payload len: ", str(payload_length)) self._log("handle_init_packet - buffer len: ", str(len(buffer))) @@ -261,8 +266,22 @@ class Channel: if not self.is_cont_packet_expected: raise ThpError("Continuation packet is not expected, ignoring") + if self.fallback_decrypt: - pass # TODO + to_read_len = min( + len(packet) - CONT_HEADER_LENGTH, + self.expected_payload_length - self.bytes_read, + ) + buf = memoryview(self.buffer)[:to_read_len] + utils.memcpy(buf, 0, packet, CONT_HEADER_LENGTH) + + # CRC CHECK + self._handle_fallback_crc(buf) + + # TAG CHECK + self._handle_fallback_decryption(buf) + + self.bytes_read += to_read_len return try: buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) @@ -281,6 +300,7 @@ class Channel: self.expected_payload_length = 0 self.is_cont_packet_expected = False + def _finish_fallback(self) -> None: self.fallback_decrypt = False self.busy_decoder = None @@ -304,6 +324,7 @@ class Channel: self.temp_crc = 0 self.temp_crc_compare = bytearray(4) self.temp_tag = bytearray(16) + self.bytes_read = INIT_HEADER_LENGTH def decrypt_buffer( self, message_length: int, offset: int = INIT_HEADER_LENGTH