1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-18 11:21:11 +00:00

fixup! wip: single packet decryption (not finished) [no changelog]

This commit is contained in:
M1nd3r 2024-12-18 11:00:07 +01:00
parent 8314f8943b
commit 04f6a3e04a
4 changed files with 62 additions and 75 deletions

View File

@ -39,8 +39,10 @@ if TYPE_CHECKING:
PinMismatch = 12 PinMismatch = 12
WipeCodeMismatch = 13 WipeCodeMismatch = 13
InvalidSession = 14 InvalidSession = 14
ThpUnallocatedSession = 15 DeviceIsBusy = 15
InvalidProtocol = 16 ThpUnallocatedSession = 16
InvalidProtocol = 17
BufferError = 18
FirmwareError = 99 FirmwareError = 99
class ButtonRequestType(IntEnum): class ButtonRequestType(IntEnum):

View File

@ -14,6 +14,11 @@ class SilentError(Exception):
self.message = message self.message = message
class WireBufferError(Error):
def __init__(self, message: str = "Buffer error") -> None:
super().__init__(FailureType.BufferError, message)
class UnexpectedMessage(Error): class UnexpectedMessage(Error):
def __init__(self, message: str) -> None: def __init__(self, message: str) -> None:
super().__init__(FailureType.UnexpectedMessage, message) super().__init__(FailureType.UnexpectedMessage, message)

View File

@ -15,6 +15,7 @@ from storage.cache_thp import (
clear_sessions_with_channel_id, clear_sessions_with_channel_id,
) )
from trezor import log, loop, protobuf, utils, workflow from trezor import log, loop, protobuf, utils, workflow
from trezor.wire.errors import WireBufferError
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
from . import alternating_bit_protocol as ABP from . import alternating_bit_protocol as ABP
@ -81,8 +82,8 @@ class Channel:
self.connection_context: PairingContext | None = None self.connection_context: PairingContext | None = None
self.busy_decoder: crypto.BusyDecoder | None = None self.busy_decoder: crypto.BusyDecoder | None = None
self.temp_crc: int | None = None self.temp_crc: int | None = None
self.temp_crc_compare: bytes | None = None self.temp_crc_compare: bytearray | None = None
self.temp_tag: bytes | None = None self.temp_tag: bytearray | None = None
def clear(self) -> None: def clear(self) -> None:
clear_sessions_with_channel_id(self.channel_id) clear_sessions_with_channel_id(self.channel_id)
@ -119,7 +120,7 @@ class Channel:
try: try:
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
except BufferError: except WireBufferError:
pass # TODO ?? pass # TODO ??
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
@ -170,67 +171,22 @@ class Channel:
length = payload_length + INIT_HEADER_LENGTH length = payload_length + INIT_HEADER_LENGTH
try: try:
buffer = memory_manager.get_new_read_buffer(cid, length) buffer = memory_manager.get_new_read_buffer(cid, length)
except BufferError: except WireBufferError:
# TODO handle not encrypted/(short??), eg. ACK # TODO handle not encrypted/(short??), eg. ACK
self.fallback_decrypt = True self.fallback_decrypt = True
self._prepare_busy_decoder()
self._prepare_fallback()
to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length) to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length)
buf = memoryview(self.buffer)[:to_read_len] buf = memoryview(self.buffer)[:to_read_len]
self.temp_crc = checksum.compute_int(data=packet[:INIT_HEADER_LENGTH])
self.temp_crc_compare = bytearray(4)
self.temp_tag = bytearray(16)
utils.memcpy(buf, 0, packet, INIT_HEADER_LENGTH) utils.memcpy(buf, 0, packet, INIT_HEADER_LENGTH)
# TODO handle: CRC in init packet, CRC partially in init packet, CRC in some cont packet
# instead of whole buf use only part without CRC
#
# bytes_read=0, buffer_len, payload_len
# crc:
# 1) payload_len >= buffer_len + CHKSUM_LEN -> return buffer_len
# 2) payload_len == buffer_len -> return payload_len - CHKSUM_LEN
# 3) payload_len > buffer_len -> return payload_len - CHKSUM_LEN
#
# noise tag:
# 1) payload_len >= buffer_len + TAG_LEN + CHKSUM_LEN -> return buffer_len
# 2) payload_len == buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN
# 3) payload_len > buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN
#
# CRC CHECK # CRC CHECK
crc_copy_len: int = 0 self._handle_fallback_crc(buf)
if payload_length > len(buf) + CHECKSUM_LENGTH:
crc_copy_len = len(buf)
elif payload_length == len(buf):
crc_copy_len = payload_length - CHECKSUM_LENGTH
crc_checksum_last_part = buf[-CHECKSUM_LENGTH:]
offset = CHECKSUM_LENGTH - len(crc_checksum_last_part)
utils.memcpy(self.temp_crc_compare, offset, crc_checksum_last_part, 0)
elif payload_length > len(buf):
crc_copy_len = payload_length - CHECKSUM_LENGTH
crc_checksum_first_part = buf[
-CHECKSUM_LENGTH + payload_length - len(buf)
]
utils.memcpy(self.temp_crc_compare, 0, crc_checksum_first_part, 0)
else:
raise Exception("Buffer should not be bigger than payload")
self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc)
# TAG CHECK # TAG CHECK
assert self.busy_decoder is not None self._handle_fallback_decryption(buf)
if payload_length > len(buf) + TAG_LENGTH + CHECKSUM_LENGTH:
self.busy_decoder.decrypt_part(buf)
elif payload_length > len(buf):
self.busy_decoder.decrypt_part(
buf[: payload_length - TAG_LENGTH - CHECKSUM_LENGTH]
)
# TODO add part of the "tag from message" to compare
else:
raise Exception("Buffer should not be bigger than payload")
# TODO decrypt packet by packet, keep track of length, at the end call _finish_message to clear mess
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("handle_init_packet - payload len: ", str(payload_length)) self._log("handle_init_packet - payload len: ", str(payload_length))
@ -238,38 +194,55 @@ class Channel:
self._buffer_packet_data(buffer, packet, 0) self._buffer_packet_data(buffer, packet, 0)
def _handle_fallback_crc(self, payload_length: int, buf: memoryview): def _handle_fallback_crc(self, buf: memoryview) -> None:
if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH: assert self.temp_crc is not None
assert self.temp_crc_compare is not None
if self.expected_payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH:
# The CRC checksum is not in this packet, compute crc over whole buffer # The CRC checksum is not in this packet, compute crc over whole buffer
self.temp_crc = checksum.compute_int(buf, self.temp_crc) self.temp_crc = checksum.compute_int(buf, self.temp_crc)
elif payload_length >= len(buf) + self.bytes_read: elif self.expected_payload_length >= len(buf) + self.bytes_read:
# At least a part of the CRC checksum is in this packet, compute CRC over # At least a part of the CRC checksum is in this packet, compute CRC over
# first (max(0, crc_copy_len)) bytes and add the rest of the bytes # first (max(0, crc_copy_len)) bytes and add the rest of the bytes (max 4)
# as the checksum from message into temp_crc_compare # as the checksum from message into temp_crc_compare
crc_copy_len = payload_length - self.bytes_read - CHECKSUM_LENGTH crc_copy_len = (
self.expected_payload_length - self.bytes_read - CHECKSUM_LENGTH
)
self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc) self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc)
crc_checksum = buf[ crc_checksum = buf[
payload_length - CHECKSUM_LENGTH - len(buf) - self.bytes_read : self.expected_payload_length
- CHECKSUM_LENGTH
- len(buf)
- self.bytes_read :
] ]
offset = CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH:]) offset = CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH:])
utils.memcpy(self.temp_crc_compare, offset, crc_checksum, 0) utils.memcpy(self.temp_crc_compare, offset, crc_checksum, 0)
else: else:
raise Exception("Buffer (+bytes_read) should not be bigger than payload") raise Exception("Buffer (+bytes_read) should not be bigger than payload")
def _handle_fallback_decryption(self, payload_length: int, buf: memoryview): def _handle_fallback_decryption(self, buf: memoryview) -> None:
if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH: assert self.busy_decoder is not None
assert self.temp_tag is not None
if (
self.expected_payload_length
> len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH
):
# The noise tag is not in this packet, decrypt the whole buffer # The noise tag is not in this packet, decrypt the whole buffer
self.busy_decoder.decrypt_part(buf) self.busy_decoder.decrypt_part(buf)
elif payload_length >= len(buf) + self.bytes_read: elif self.expected_payload_length >= len(buf) + self.bytes_read:
# At least a part of the CRC checksum is in this packet, compute CRC over # At least a part of the CRC checksum is in this packet, compute CRC over
# first (max(0, crc_copy_len)) bytes and add the rest of the bytes # first (max(0, crc_copy_len)) bytes and add the rest of the bytes
# as the checksum from message into temp_crc_compare # as the checksum from message into temp_crc_compare
dec_len = payload_length - self.bytes_read - TAG_LENGTH - CHECKSUM_LENGTH dec_len = (
self.expected_payload_length
- self.bytes_read
- TAG_LENGTH
- CHECKSUM_LENGTH
)
self.busy_decoder.decrypt_part(buf[:dec_len]) self.busy_decoder.decrypt_part(buf[:dec_len])
noise_tag = buf[ noise_tag = buf[
payload_length self.expected_payload_length
- CHECKSUM_LENGTH - CHECKSUM_LENGTH
- TAG_LENGTH - TAG_LENGTH
- len(buf) - len(buf)
@ -293,7 +266,7 @@ class Channel:
return return
try: try:
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
except BufferError: except WireBufferError:
self.set_channel_state(ChannelState.INVALIDATED) self.set_channel_state(ChannelState.INVALIDATED)
pass # TODO handle device busy, channel kaput pass # TODO handle device busy, channel kaput
self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH) self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH)
@ -317,7 +290,8 @@ class Channel:
# crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload)) # crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
return payload return payload
def _prepare_busy_decoder(self) -> None: def _prepare_fallback(self) -> None:
# prepare busy decoder
key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE) key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE) nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
@ -326,11 +300,16 @@ class Channel:
self.busy_decoder = crypto.BusyDecoder(key_receive, nonce_receive) self.busy_decoder = crypto.BusyDecoder(key_receive, nonce_receive)
# prepare temp channel values
self.temp_crc = 0
self.temp_crc_compare = bytearray(4)
self.temp_tag = bytearray(16)
def decrypt_buffer( def decrypt_buffer(
self, message_length: int, offset: int = INIT_HEADER_LENGTH self, message_length: int, offset: int = INIT_HEADER_LENGTH
) -> None: ) -> None:
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
# if buffer is BufferError: # if buffer is WireBufferError:
# pass # TODO handle deviceBUSY # pass # TODO handle deviceBUSY
noise_buffer = memoryview(buffer)[ noise_buffer = memoryview(buffer)[
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
@ -387,7 +366,7 @@ class Channel:
noise_payload_len = memory_manager.encode_into_buffer( noise_payload_len = memory_manager.encode_into_buffer(
buffer, msg, session_id buffer, msg, session_id
) )
except BufferError: except WireBufferError:
from trezor.messages import Failure, FailureType from trezor.messages import Failure, FailureType
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
@ -421,7 +400,7 @@ class Channel:
def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]: def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]:
payload_length = len(payload) payload_length = len(payload)
buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int()) buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int())
# if buffer is BufferError: # if buffer is WireBufferError:
# pass # TODO handle deviceBUSY # pass # TODO handle deviceBUSY
self._encrypt(buffer, payload_length) self._encrypt(buffer, payload_length)

View File

@ -3,6 +3,7 @@ from micropython import const
from storage.cache_thp import SESSION_ID_LENGTH from storage.cache_thp import SESSION_ID_LENGTH
from trezor import protobuf, utils from trezor import protobuf, utils
from trezor.wire.errors import WireBufferError
from trezor.wire.message_handler import get_msg_type from trezor.wire.message_handler import get_msg_type
from . import ThpError from . import ThpError
@ -48,7 +49,7 @@ def get_existing_write_buffer(channel_id: int) -> memoryview:
def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview: def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview:
if is_locked(): if is_locked():
if not is_owner(channel_id): if not is_owner(channel_id):
raise BufferError raise WireBufferError
update_lock_time() update_lock_time()
else: else:
update_lock(channel_id) update_lock(channel_id)
@ -80,19 +81,19 @@ def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryvie
def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview: def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview:
if not is_owner(channel_id): if not is_owner(channel_id):
raise BufferError raise WireBufferError
update_lock_time() update_lock_time()
if buffer_type == _READ: if buffer_type == _READ:
global READ_BUFFER_SLICE global READ_BUFFER_SLICE
if READ_BUFFER_SLICE is None: if READ_BUFFER_SLICE is None:
raise BufferError raise WireBufferError
return READ_BUFFER_SLICE return READ_BUFFER_SLICE
if buffer_type == _WRITE: if buffer_type == _WRITE:
global WRITE_BUFFER_SLICE global WRITE_BUFFER_SLICE
if WRITE_BUFFER_SLICE is None: if WRITE_BUFFER_SLICE is None:
raise BufferError raise WireBufferError
return WRITE_BUFFER_SLICE return WRITE_BUFFER_SLICE
raise Exception("Invalid buffer_type") raise Exception("Invalid buffer_type")