mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-22 05:10:56 +00:00
fixup! wip: single packet decryption (not finished) [no changelog]
This commit is contained in:
parent
7d3a45fc99
commit
5c5c8df83e
6
core/src/trezor/enums/__init__.py
generated
6
core/src/trezor/enums/__init__.py
generated
@ -39,8 +39,10 @@ if TYPE_CHECKING:
|
|||||||
PinMismatch = 12
|
PinMismatch = 12
|
||||||
WipeCodeMismatch = 13
|
WipeCodeMismatch = 13
|
||||||
InvalidSession = 14
|
InvalidSession = 14
|
||||||
ThpUnallocatedSession = 15
|
DeviceIsBusy = 15
|
||||||
InvalidProtocol = 16
|
ThpUnallocatedSession = 16
|
||||||
|
InvalidProtocol = 17
|
||||||
|
BufferError = 18
|
||||||
FirmwareError = 99
|
FirmwareError = 99
|
||||||
|
|
||||||
class ButtonRequestType(IntEnum):
|
class ButtonRequestType(IntEnum):
|
||||||
|
@ -14,6 +14,11 @@ class SilentError(Exception):
|
|||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
class WireBufferError(Error):
|
||||||
|
def __init__(self, message: str = "Buffer error") -> None:
|
||||||
|
super().__init__(FailureType.BufferError, message)
|
||||||
|
|
||||||
|
|
||||||
class UnexpectedMessage(Error):
|
class UnexpectedMessage(Error):
|
||||||
def __init__(self, message: str) -> None:
|
def __init__(self, message: str) -> None:
|
||||||
super().__init__(FailureType.UnexpectedMessage, message)
|
super().__init__(FailureType.UnexpectedMessage, message)
|
||||||
|
@ -15,6 +15,7 @@ from storage.cache_thp import (
|
|||||||
clear_sessions_with_channel_id,
|
clear_sessions_with_channel_id,
|
||||||
)
|
)
|
||||||
from trezor import log, loop, protobuf, utils, workflow
|
from trezor import log, loop, protobuf, utils, workflow
|
||||||
|
from trezor.wire.errors import WireBufferError
|
||||||
|
|
||||||
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
|
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
|
||||||
from . import alternating_bit_protocol as ABP
|
from . import alternating_bit_protocol as ABP
|
||||||
@ -81,8 +82,8 @@ class Channel:
|
|||||||
self.connection_context: PairingContext | None = None
|
self.connection_context: PairingContext | None = None
|
||||||
self.busy_decoder: crypto.BusyDecoder | None = None
|
self.busy_decoder: crypto.BusyDecoder | None = None
|
||||||
self.temp_crc: int | None = None
|
self.temp_crc: int | None = None
|
||||||
self.temp_crc_compare: bytes | None = None
|
self.temp_crc_compare: bytearray | None = None
|
||||||
self.temp_tag: bytes | None = None
|
self.temp_tag: bytearray | None = None
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
clear_sessions_with_channel_id(self.channel_id)
|
clear_sessions_with_channel_id(self.channel_id)
|
||||||
@ -119,7 +120,7 @@ class Channel:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||||
except BufferError:
|
except WireBufferError:
|
||||||
pass # TODO ??
|
pass # TODO ??
|
||||||
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
@ -170,67 +171,22 @@ class Channel:
|
|||||||
length = payload_length + INIT_HEADER_LENGTH
|
length = payload_length + INIT_HEADER_LENGTH
|
||||||
try:
|
try:
|
||||||
buffer = memory_manager.get_new_read_buffer(cid, length)
|
buffer = memory_manager.get_new_read_buffer(cid, length)
|
||||||
except BufferError:
|
except WireBufferError:
|
||||||
# TODO handle not encrypted/(short??), eg. ACK
|
# TODO handle not encrypted/(short??), eg. ACK
|
||||||
|
|
||||||
self.fallback_decrypt = True
|
self.fallback_decrypt = True
|
||||||
self._prepare_busy_decoder()
|
|
||||||
|
self._prepare_fallback()
|
||||||
|
|
||||||
to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length)
|
to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length)
|
||||||
buf = memoryview(self.buffer)[:to_read_len]
|
buf = memoryview(self.buffer)[:to_read_len]
|
||||||
self.temp_crc = checksum.compute_int(data=packet[:INIT_HEADER_LENGTH])
|
|
||||||
self.temp_crc_compare = bytearray(4)
|
|
||||||
self.temp_tag = bytearray(16)
|
|
||||||
utils.memcpy(buf, 0, packet, INIT_HEADER_LENGTH)
|
utils.memcpy(buf, 0, packet, INIT_HEADER_LENGTH)
|
||||||
|
|
||||||
# TODO handle: CRC in init packet, CRC partially in init packet, CRC in some cont packet
|
|
||||||
# instead of whole buf use only part without CRC
|
|
||||||
#
|
|
||||||
# bytes_read=0, buffer_len, payload_len
|
|
||||||
# crc:
|
|
||||||
# 1) payload_len >= buffer_len + CHKSUM_LEN -> return buffer_len
|
|
||||||
# 2) payload_len == buffer_len -> return payload_len - CHKSUM_LEN
|
|
||||||
# 3) payload_len > buffer_len -> return payload_len - CHKSUM_LEN
|
|
||||||
#
|
|
||||||
# noise tag:
|
|
||||||
# 1) payload_len >= buffer_len + TAG_LEN + CHKSUM_LEN -> return buffer_len
|
|
||||||
# 2) payload_len == buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN
|
|
||||||
# 3) payload_len > buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN
|
|
||||||
#
|
|
||||||
|
|
||||||
# CRC CHECK
|
# CRC CHECK
|
||||||
crc_copy_len: int = 0
|
self._handle_fallback_crc(buf)
|
||||||
if payload_length > len(buf) + CHECKSUM_LENGTH:
|
|
||||||
crc_copy_len = len(buf)
|
|
||||||
elif payload_length == len(buf):
|
|
||||||
crc_copy_len = payload_length - CHECKSUM_LENGTH
|
|
||||||
crc_checksum_last_part = buf[-CHECKSUM_LENGTH:]
|
|
||||||
offset = CHECKSUM_LENGTH - len(crc_checksum_last_part)
|
|
||||||
utils.memcpy(self.temp_crc_compare, offset, crc_checksum_last_part, 0)
|
|
||||||
elif payload_length > len(buf):
|
|
||||||
crc_copy_len = payload_length - CHECKSUM_LENGTH
|
|
||||||
crc_checksum_first_part = buf[
|
|
||||||
-CHECKSUM_LENGTH + payload_length - len(buf)
|
|
||||||
]
|
|
||||||
utils.memcpy(self.temp_crc_compare, 0, crc_checksum_first_part, 0)
|
|
||||||
else:
|
|
||||||
raise Exception("Buffer should not be bigger than payload")
|
|
||||||
self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc)
|
|
||||||
|
|
||||||
# TAG CHECK
|
# TAG CHECK
|
||||||
assert self.busy_decoder is not None
|
self._handle_fallback_decryption(buf)
|
||||||
|
|
||||||
if payload_length > len(buf) + TAG_LENGTH + CHECKSUM_LENGTH:
|
|
||||||
self.busy_decoder.decrypt_part(buf)
|
|
||||||
elif payload_length > len(buf):
|
|
||||||
self.busy_decoder.decrypt_part(
|
|
||||||
buf[: payload_length - TAG_LENGTH - CHECKSUM_LENGTH]
|
|
||||||
)
|
|
||||||
# TODO add part of the "tag from message" to compare
|
|
||||||
else:
|
|
||||||
raise Exception("Buffer should not be bigger than payload")
|
|
||||||
|
|
||||||
# TODO decrypt packet by packet, keep track of length, at the end call _finish_message to clear mess
|
|
||||||
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("handle_init_packet - payload len: ", str(payload_length))
|
self._log("handle_init_packet - payload len: ", str(payload_length))
|
||||||
@ -238,38 +194,55 @@ class Channel:
|
|||||||
|
|
||||||
self._buffer_packet_data(buffer, packet, 0)
|
self._buffer_packet_data(buffer, packet, 0)
|
||||||
|
|
||||||
def _handle_fallback_crc(self, payload_length: int, buf: memoryview):
|
def _handle_fallback_crc(self, buf: memoryview) -> None:
|
||||||
if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH:
|
assert self.temp_crc is not None
|
||||||
|
assert self.temp_crc_compare is not None
|
||||||
|
if self.expected_payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH:
|
||||||
# The CRC checksum is not in this packet, compute crc over whole buffer
|
# The CRC checksum is not in this packet, compute crc over whole buffer
|
||||||
self.temp_crc = checksum.compute_int(buf, self.temp_crc)
|
self.temp_crc = checksum.compute_int(buf, self.temp_crc)
|
||||||
elif payload_length >= len(buf) + self.bytes_read:
|
elif self.expected_payload_length >= len(buf) + self.bytes_read:
|
||||||
# At least a part of the CRC checksum is in this packet, compute CRC over
|
# At least a part of the CRC checksum is in this packet, compute CRC over
|
||||||
# first (max(0, crc_copy_len)) bytes and add the rest of the bytes
|
# first (max(0, crc_copy_len)) bytes and add the rest of the bytes (max 4)
|
||||||
# as the checksum from message into temp_crc_compare
|
# as the checksum from message into temp_crc_compare
|
||||||
crc_copy_len = payload_length - self.bytes_read - CHECKSUM_LENGTH
|
crc_copy_len = (
|
||||||
|
self.expected_payload_length - self.bytes_read - CHECKSUM_LENGTH
|
||||||
|
)
|
||||||
self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc)
|
self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc)
|
||||||
|
|
||||||
crc_checksum = buf[
|
crc_checksum = buf[
|
||||||
payload_length - CHECKSUM_LENGTH - len(buf) - self.bytes_read :
|
self.expected_payload_length
|
||||||
|
- CHECKSUM_LENGTH
|
||||||
|
- len(buf)
|
||||||
|
- self.bytes_read :
|
||||||
]
|
]
|
||||||
offset = CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH:])
|
offset = CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH:])
|
||||||
utils.memcpy(self.temp_crc_compare, offset, crc_checksum, 0)
|
utils.memcpy(self.temp_crc_compare, offset, crc_checksum, 0)
|
||||||
else:
|
else:
|
||||||
raise Exception("Buffer (+bytes_read) should not be bigger than payload")
|
raise Exception("Buffer (+bytes_read) should not be bigger than payload")
|
||||||
|
|
||||||
def _handle_fallback_decryption(self, payload_length: int, buf: memoryview):
|
def _handle_fallback_decryption(self, buf: memoryview) -> None:
|
||||||
if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH:
|
assert self.busy_decoder is not None
|
||||||
|
assert self.temp_tag is not None
|
||||||
|
if (
|
||||||
|
self.expected_payload_length
|
||||||
|
> len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH
|
||||||
|
):
|
||||||
# The noise tag is not in this packet, decrypt the whole buffer
|
# The noise tag is not in this packet, decrypt the whole buffer
|
||||||
self.busy_decoder.decrypt_part(buf)
|
self.busy_decoder.decrypt_part(buf)
|
||||||
elif payload_length >= len(buf) + self.bytes_read:
|
elif self.expected_payload_length >= len(buf) + self.bytes_read:
|
||||||
# At least a part of the CRC checksum is in this packet, compute CRC over
|
# At least a part of the CRC checksum is in this packet, compute CRC over
|
||||||
# first (max(0, crc_copy_len)) bytes and add the rest of the bytes
|
# first (max(0, crc_copy_len)) bytes and add the rest of the bytes
|
||||||
# as the checksum from message into temp_crc_compare
|
# as the checksum from message into temp_crc_compare
|
||||||
dec_len = payload_length - self.bytes_read - TAG_LENGTH - CHECKSUM_LENGTH
|
dec_len = (
|
||||||
|
self.expected_payload_length
|
||||||
|
- self.bytes_read
|
||||||
|
- TAG_LENGTH
|
||||||
|
- CHECKSUM_LENGTH
|
||||||
|
)
|
||||||
self.busy_decoder.decrypt_part(buf[:dec_len])
|
self.busy_decoder.decrypt_part(buf[:dec_len])
|
||||||
|
|
||||||
noise_tag = buf[
|
noise_tag = buf[
|
||||||
payload_length
|
self.expected_payload_length
|
||||||
- CHECKSUM_LENGTH
|
- CHECKSUM_LENGTH
|
||||||
- TAG_LENGTH
|
- TAG_LENGTH
|
||||||
- len(buf)
|
- len(buf)
|
||||||
@ -293,7 +266,7 @@ class Channel:
|
|||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||||
except BufferError:
|
except WireBufferError:
|
||||||
self.set_channel_state(ChannelState.INVALIDATED)
|
self.set_channel_state(ChannelState.INVALIDATED)
|
||||||
pass # TODO handle device busy, channel kaput
|
pass # TODO handle device busy, channel kaput
|
||||||
self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH)
|
self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH)
|
||||||
@ -317,7 +290,8 @@ class Channel:
|
|||||||
# crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
|
# crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
def _prepare_busy_decoder(self) -> None:
|
def _prepare_fallback(self) -> None:
|
||||||
|
# prepare busy decoder
|
||||||
key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
|
key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE)
|
||||||
nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
|
nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE)
|
||||||
|
|
||||||
@ -326,11 +300,16 @@ class Channel:
|
|||||||
|
|
||||||
self.busy_decoder = crypto.BusyDecoder(key_receive, nonce_receive)
|
self.busy_decoder = crypto.BusyDecoder(key_receive, nonce_receive)
|
||||||
|
|
||||||
|
# prepare temp channel values
|
||||||
|
self.temp_crc = 0
|
||||||
|
self.temp_crc_compare = bytearray(4)
|
||||||
|
self.temp_tag = bytearray(16)
|
||||||
|
|
||||||
def decrypt_buffer(
|
def decrypt_buffer(
|
||||||
self, message_length: int, offset: int = INIT_HEADER_LENGTH
|
self, message_length: int, offset: int = INIT_HEADER_LENGTH
|
||||||
) -> None:
|
) -> None:
|
||||||
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||||
# if buffer is BufferError:
|
# if buffer is WireBufferError:
|
||||||
# pass # TODO handle deviceBUSY
|
# pass # TODO handle deviceBUSY
|
||||||
noise_buffer = memoryview(buffer)[
|
noise_buffer = memoryview(buffer)[
|
||||||
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
|
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
|
||||||
@ -388,7 +367,7 @@ class Channel:
|
|||||||
noise_payload_len = memory_manager.encode_into_buffer(
|
noise_payload_len = memory_manager.encode_into_buffer(
|
||||||
buffer, msg, session_id
|
buffer, msg, session_id
|
||||||
)
|
)
|
||||||
except BufferError:
|
except WireBufferError:
|
||||||
from trezor.messages import Failure, FailureType
|
from trezor.messages import Failure, FailureType
|
||||||
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
@ -424,7 +403,7 @@ class Channel:
|
|||||||
) -> Awaitable[None] | None:
|
) -> Awaitable[None] | None:
|
||||||
payload_length = len(payload)
|
payload_length = len(payload)
|
||||||
buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int())
|
buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int())
|
||||||
# if buffer is BufferError:
|
# if buffer is WireBufferError:
|
||||||
# pass # TODO handle deviceBUSY
|
# pass # TODO handle deviceBUSY
|
||||||
|
|
||||||
self._encrypt(buffer, payload_length)
|
self._encrypt(buffer, payload_length)
|
||||||
|
@ -3,6 +3,7 @@ from micropython import const
|
|||||||
|
|
||||||
from storage.cache_thp import SESSION_ID_LENGTH
|
from storage.cache_thp import SESSION_ID_LENGTH
|
||||||
from trezor import protobuf, utils
|
from trezor import protobuf, utils
|
||||||
|
from trezor.wire.errors import WireBufferError
|
||||||
from trezor.wire.message_handler import get_msg_type
|
from trezor.wire.message_handler import get_msg_type
|
||||||
|
|
||||||
from . import ThpError
|
from . import ThpError
|
||||||
@ -48,7 +49,7 @@ def get_existing_write_buffer(channel_id: int) -> memoryview:
|
|||||||
def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview:
|
def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview:
|
||||||
if is_locked():
|
if is_locked():
|
||||||
if not is_owner(channel_id):
|
if not is_owner(channel_id):
|
||||||
raise BufferError
|
raise WireBufferError
|
||||||
update_lock_time()
|
update_lock_time()
|
||||||
else:
|
else:
|
||||||
update_lock(channel_id)
|
update_lock(channel_id)
|
||||||
@ -80,19 +81,19 @@ def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryvie
|
|||||||
|
|
||||||
def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview:
|
def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview:
|
||||||
if not is_owner(channel_id):
|
if not is_owner(channel_id):
|
||||||
raise BufferError
|
raise WireBufferError
|
||||||
update_lock_time()
|
update_lock_time()
|
||||||
|
|
||||||
if buffer_type == _READ:
|
if buffer_type == _READ:
|
||||||
global READ_BUFFER_SLICE
|
global READ_BUFFER_SLICE
|
||||||
if READ_BUFFER_SLICE is None:
|
if READ_BUFFER_SLICE is None:
|
||||||
raise BufferError
|
raise WireBufferError
|
||||||
return READ_BUFFER_SLICE
|
return READ_BUFFER_SLICE
|
||||||
|
|
||||||
if buffer_type == _WRITE:
|
if buffer_type == _WRITE:
|
||||||
global WRITE_BUFFER_SLICE
|
global WRITE_BUFFER_SLICE
|
||||||
if WRITE_BUFFER_SLICE is None:
|
if WRITE_BUFFER_SLICE is None:
|
||||||
raise BufferError
|
raise WireBufferError
|
||||||
return WRITE_BUFFER_SLICE
|
return WRITE_BUFFER_SLICE
|
||||||
|
|
||||||
raise Exception("Invalid buffer_type")
|
raise Exception("Invalid buffer_type")
|
||||||
|
Loading…
Reference in New Issue
Block a user