mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-10 07:20:56 +00:00
fixup! wip: single packet decryption (not finished) [no changelog]
This commit is contained in:
parent
38589d7d2e
commit
5180328bae
6
core/src/trezor/enums/__init__.py
generated
6
core/src/trezor/enums/__init__.py
generated
@ -39,8 +39,10 @@ if TYPE_CHECKING:
|
||||
PinMismatch = 12
|
||||
WipeCodeMismatch = 13
|
||||
InvalidSession = 14
|
||||
ThpUnallocatedSession = 15
|
||||
InvalidProtocol = 16
|
||||
DeviceIsBusy = 15
|
||||
ThpUnallocatedSession = 16
|
||||
InvalidProtocol = 17
|
||||
BufferError = 18
|
||||
FirmwareError = 99
|
||||
|
||||
class ButtonRequestType(IntEnum):
|
||||
|
@ -14,6 +14,11 @@ class SilentError(Exception):
|
||||
self.message = message
|
||||
|
||||
|
||||
class WireBufferError(Error):
|
||||
def __init__(self, message: str = "Buffer error") -> None:
|
||||
super().__init__(FailureType.BufferError, message)
|
||||
|
||||
|
||||
class UnexpectedMessage(Error):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(FailureType.UnexpectedMessage, message)
|
||||
|
@ -15,6 +15,7 @@ from storage.cache_thp import (
|
||||
clear_sessions_with_channel_id,
|
||||
)
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.wire.errors import WireBufferError
|
||||
|
||||
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
|
||||
from . import alternating_bit_protocol as ABP
|
||||
@ -81,8 +82,8 @@ class Channel:
|
||||
self.connection_context: PairingContext | None = None
|
||||
self.busy_decoder: crypto.BusyDecoder | None = None
|
||||
self.temp_crc: int | None = None
|
||||
self.temp_crc_compare: bytes | None = None
|
||||
self.temp_tag: bytes | None = None
|
||||
self.temp_crc_compare: bytearray | None = None
|
||||
self.temp_tag: bytearray | None = None
|
||||
|
||||
def clear(self) -> None:
|
||||
clear_sessions_with_channel_id(self.channel_id)
|
||||
@ -119,7 +120,7 @@ class Channel:
|
||||
|
||||
try:
|
||||
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||
except BufferError:
|
||||
except WireBufferError:
|
||||
pass # TODO ??
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
@ -170,67 +171,22 @@ class Channel:
|
||||
length = payload_length + INIT_HEADER_LENGTH
|
||||
try:
|
||||
buffer = memory_manager.get_new_read_buffer(cid, length)
|
||||
except BufferError:
|
||||
except WireBufferError:
|
||||
# TODO handle not encrypted/(short??), eg. ACK
|
||||
|
||||
self.fallback_decrypt = True
|
||||
self._prepare_busy_decoder()
|
||||
|
||||
self._prepare_fallback()
|
||||
|
||||
to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length)
|
||||
buf = memoryview(self.buffer)[:to_read_len]
|
||||
self.temp_crc = checksum.compute_int(data=packet[:INIT_HEADER_LENGTH])
|
||||
self.temp_crc_compare = bytearray(4)
|
||||
self.temp_tag = bytearray(16)
|
||||
utils.memcpy(buf, 0, packet, INIT_HEADER_LENGTH)
|
||||
|
||||
# TODO handle: CRC in init packet, CRC partially in init packet, CRC in some cont packet
|
||||
# instead of whole buf use only part without CRC
|
||||
#
|
||||
# bytes_read=0, buffer_len, payload_len
|
||||
# crc:
|
||||
# 1) payload_len >= buffer_len + CHKSUM_LEN -> return buffer_len
|
||||
# 2) payload_len == buffer_len -> return payload_len - CHKSUM_LEN
|
||||
# 3) payload_len > buffer_len -> return payload_len - CHKSUM_LEN
|
||||
#
|
||||
# noise tag:
|
||||
# 1) payload_len >= buffer_len + TAG_LEN + CHKSUM_LEN -> return buffer_len
|
||||
# 2) payload_len == buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN
|
||||
# 3) payload_len > buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN
|
||||
#
|
||||
|
||||
# CRC CHECK
|
||||
crc_copy_len: int = 0
|
||||
if payload_length > len(buf) + CHECKSUM_LENGTH:
|
||||
crc_copy_len = len(buf)
|
||||
elif payload_length == len(buf):
|
||||
crc_copy_len = payload_length - CHECKSUM_LENGTH
|
||||
crc_checksum_last_part = buf[-CHECKSUM_LENGTH:]
|
||||
offset = CHECKSUM_LENGTH - len(crc_checksum_last_part)
|
||||
utils.memcpy(self.temp_crc_compare, offset, crc_checksum_last_part, 0)
|
||||
elif payload_length > len(buf):
|
||||
crc_copy_len = payload_length - CHECKSUM_LENGTH
|
||||
crc_checksum_first_part = buf[
|
||||
-CHECKSUM_LENGTH + payload_length - len(buf)
|
||||
]
|
||||
utils.memcpy(self.temp_crc_compare, 0, crc_checksum_first_part, 0)
|
||||
else:
|
||||
raise Exception("Buffer should not be bigger than payload")
|
||||
self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc)
|
||||
self._handle_fallback_crc(buf)
|
||||
|
||||
# TAG CHECK
|
||||
assert self.busy_decoder is not None
|
||||
|
||||
if payload_length > len(buf) + TAG_LENGTH + CHECKSUM_LENGTH:
|
||||
self.busy_decoder.decrypt_part(buf)
|
||||
elif payload_length > len(buf):
|
||||
self.busy_decoder.decrypt_part(
|
||||
buf[: payload_length - TAG_LENGTH - CHECKSUM_LENGTH]
|
||||
)
|
||||
# TODO add part of the "tag from message" to compare
|
||||
else:
|
||||
raise Exception("Buffer should not be bigger than payload")
|
||||
|
||||
# TODO decrypt packet by packet, keep track of length, at the end call _finish_message to clear mess
|
||||
self._handle_fallback_decryption(buf)
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
self._log("handle_init_packet - payload len: ", str(payload_length))
|
||||
@ -238,38 +194,55 @@ class Channel:
|
||||
|
||||
self._buffer_packet_data(buffer, packet, 0)
|
||||
|
||||
def _handle_fallback_crc(self, payload_length: int, buf: memoryview):
|
||||
if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH:
|
||||
def _handle_fallback_crc(self, buf: memoryview) -> None:
|
||||
assert self.temp_crc is not None
|
||||
assert self.temp_crc_compare is not None
|
||||
if self.expected_payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH:
|
||||
# The CRC checksum is not in this packet, compute crc over whole buffer
|
||||
self.temp_crc = checksum.compute_int(buf, self.temp_crc)
|
||||
elif payload_length >= len(buf) + self.bytes_read:
|
||||
elif self.expected_payload_length >= len(buf) + self.bytes_read:
|
||||
# At least a part of the CRC checksum is in this packet, compute CRC over
|
||||
# first (max(0, crc_copy_len)) bytes and add the rest of the bytes
|
||||
# first (max(0, crc_copy_len)) bytes and add the rest of the bytes (max 4)
|
||||
# as the checksum from message into temp_crc_compare
|
||||
crc_copy_len = payload_length - self.bytes_read - CHECKSUM_LENGTH
|
||||
crc_copy_len = (
|
||||
self.expected_payload_length - self.bytes_read - CHECKSUM_LENGTH
|
||||
)
|
||||
self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc)
|
||||
|
||||
crc_checksum = buf[
|
||||
payload_length - CHECKSUM_LENGTH - len(buf) - self.bytes_read :
|
||||
self.expected_payload_length
|
||||
- CHECKSUM_LENGTH
|
||||
- len(buf)
|
||||
- self.bytes_read :
|
||||
]
|
||||
offset = CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH:])
|
||||
utils.memcpy(self.temp_crc_compare, offset, crc_checksum, 0)
|
||||
else:
|
||||
raise Exception("Buffer (+bytes_read) should not be bigger than payload")
|
||||
|
||||
def _handle_fallback_decryption(self, payload_length: int, buf: memoryview):
|
||||
if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH:
|
||||
def _handle_fallback_decryption(self, buf: memoryview) -> None:
|
||||
assert self.busy_decoder is not None
|
||||
assert self.temp_tag is not None
|
||||
if (
|
||||
self.expected_payload_length
|
||||
> len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH
|
||||
):
|
||||
# The noise tag is not in this packet, decrypt the whole buffer
|
||||
self.busy_decoder.decrypt_part(buf)
|
||||
elif payload_length >= len(buf) + self.bytes_read:
|
||||
elif self.expected_payload_length >= len(buf) + self.bytes_read:
|
||||
# At least a part of the CRC checksum is in this packet, compute CRC over
|
||||
# first (max(0, crc_copy_len)) bytes and add the rest of the bytes
|
||||
# as the checksum from message into temp_crc_compare
|
||||
dec_len = payload_length - self.bytes_read - TAG_LENGTH - CHECKSUM_LENGTH
|
||||
dec_len = (
|
||||
self.expected_payload_length
|
||||
- self.bytes_read
|
||||
- TAG_LENGTH
|
||||
- CHECKSUM_LENGTH
|
||||
)
|
||||
self.busy_decoder.decrypt_part(buf[:dec_len])
|
||||
|
||||
noise_tag = buf[
|
||||
payload_length
|
||||
self.expected_payload_length
|
||||
- CHECKSUM_LENGTH
|
||||
- TAG_LENGTH
|
||||
- len(buf)
|
||||
@ -293,7 +266,7 @@ class Channel:
|
||||
return
|
||||
try:
|
||||
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||
except BufferError:
|
||||
except WireBufferError:
|
||||
self.set_channel_state(ChannelState.INVALIDATED)
|
||||
pass # TODO handle device busy, channel kaput
|
||||
self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH)
|
||||
@ -317,7 +290,8 @@ class Channel:
|
||||
# crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
|
||||
return payload
|
||||
|
||||
def _prepare_busy_decoder(self) -> None:
|
||||
def _prepare_fallback(self) -> None:
|
||||
# prepare busy decoder
|
||||
key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
|
||||
nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
|
||||
|
||||
@ -326,11 +300,16 @@ class Channel:
|
||||
|
||||
self.busy_decoder = crypto.BusyDecoder(key_receive, nonce_receive)
|
||||
|
||||
# prepare temp channel values
|
||||
self.temp_crc = 0
|
||||
self.temp_crc_compare = bytearray(4)
|
||||
self.temp_tag = bytearray(16)
|
||||
|
||||
def decrypt_buffer(
|
||||
self, message_length: int, offset: int = INIT_HEADER_LENGTH
|
||||
) -> None:
|
||||
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||
# if buffer is BufferError:
|
||||
# if buffer is WireBufferError:
|
||||
# pass # TODO handle deviceBUSY
|
||||
noise_buffer = memoryview(buffer)[
|
||||
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
|
||||
@ -388,7 +367,7 @@ class Channel:
|
||||
noise_payload_len = memory_manager.encode_into_buffer(
|
||||
buffer, msg, session_id
|
||||
)
|
||||
except BufferError:
|
||||
except WireBufferError:
|
||||
from trezor.messages import Failure, FailureType
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
@ -424,7 +403,7 @@ class Channel:
|
||||
) -> Awaitable[None] | None:
|
||||
payload_length = len(payload)
|
||||
buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int())
|
||||
# if buffer is BufferError:
|
||||
# if buffer is WireBufferError:
|
||||
# pass # TODO handle deviceBUSY
|
||||
|
||||
self._encrypt(buffer, payload_length)
|
||||
|
@ -3,6 +3,7 @@ from micropython import const
|
||||
|
||||
from storage.cache_thp import SESSION_ID_LENGTH
|
||||
from trezor import protobuf, utils
|
||||
from trezor.wire.errors import WireBufferError
|
||||
from trezor.wire.message_handler import get_msg_type
|
||||
|
||||
from . import ThpError
|
||||
@ -48,7 +49,7 @@ def get_existing_write_buffer(channel_id: int) -> memoryview:
|
||||
def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview:
|
||||
if is_locked():
|
||||
if not is_owner(channel_id):
|
||||
raise BufferError
|
||||
raise WireBufferError
|
||||
update_lock_time()
|
||||
else:
|
||||
update_lock(channel_id)
|
||||
@ -80,19 +81,19 @@ def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryvie
|
||||
|
||||
def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview:
|
||||
if not is_owner(channel_id):
|
||||
raise BufferError
|
||||
raise WireBufferError
|
||||
update_lock_time()
|
||||
|
||||
if buffer_type == _READ:
|
||||
global READ_BUFFER_SLICE
|
||||
if READ_BUFFER_SLICE is None:
|
||||
raise BufferError
|
||||
raise WireBufferError
|
||||
return READ_BUFFER_SLICE
|
||||
|
||||
if buffer_type == _WRITE:
|
||||
global WRITE_BUFFER_SLICE
|
||||
if WRITE_BUFFER_SLICE is None:
|
||||
raise BufferError
|
||||
raise WireBufferError
|
||||
return WRITE_BUFFER_SLICE
|
||||
|
||||
raise Exception("Invalid buffer_type")
|
||||
|
Loading…
Reference in New Issue
Block a user