1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

refactor(core): clean channel and received_message_handler

[no changelog]
This commit is contained in:
M1nd3r 2024-12-04 09:34:33 +01:00
parent 5c7f5edb80
commit 517707a1c2
2 changed files with 52 additions and 55 deletions

View File

@ -76,6 +76,7 @@ class Channel:
self.channel_cache.clear() self.channel_cache.clear()
# ACCESS TO CHANNEL_DATA # ACCESS TO CHANNEL_DATA
def get_channel_id_int(self) -> int: def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big") return int.from_bytes(self.channel_id, "big")
@ -100,7 +101,7 @@ class Channel:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("set_buffer: ", str(type(self.buffer))) self._log("set_buffer: ", str(type(self.buffer)))
# CALLED BY THP_MAIN_LOOP # READ and DECRYPT
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
@ -108,7 +109,7 @@ class Channel:
self._handle_received_packet(packet) self._handle_received_packet(packet)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("self.buffer: ", utils.get_bytes_as_str(self.buffer)) self._log("self.buffer: ", get_bytes_as_str(self.buffer))
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
self._finish_message() self._finish_message()
@ -166,6 +167,16 @@ class Channel:
raise ThpError("Continuation packet is not expected, ignoring") raise ThpError("Continuation packet is not expected, ignoring")
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH) return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
def _buffer_packet_data(
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
) -> None:
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
def _finish_message(self) -> None:
self.bytes_read = 0
self.expected_payload_length = 0
self.is_cont_packet_expected = False
def _decrypt_single_packet_payload( def _decrypt_single_packet_payload(
self, payload: utils.BufferType self, payload: utils.BufferType
) -> utils.BufferType: ) -> utils.BufferType:
@ -212,42 +223,7 @@ class Channel:
if not is_tag_valid: if not is_tag_valid:
raise ThpDecryptionError() raise ThpDecryptionError()
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None: # WRITE and ENCRYPT
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("encrypt")
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
noise_buffer = memoryview(buffer)[0:noise_payload_len]
if utils.DISABLE_ENCRYPTION:
tag = crypto.DUMMY_TAG
else:
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
assert key_send is not None
assert nonce_send is not None
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("New nonce_send: ", str((nonce_send + 1)))
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
def _buffer_packet_data(
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
) -> None:
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
def _finish_message(self) -> None:
self.bytes_read = 0
self.expected_payload_length = 0
self.is_cont_packet_expected = False
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write( async def write(
self, self,
@ -262,7 +238,7 @@ class Channel:
noise_payload_len = memory_manager.encode_into_buffer( noise_payload_len = memory_manager.encode_into_buffer(
self.buffer, msg, session_id self.buffer, msg, session_id
) )
task = self.write_and_encrypt(self.buffer[:noise_payload_len], force) task = self._write_and_encrypt(self.buffer[:noise_payload_len], force)
if task is not None: if task is not None:
await task await task
@ -272,7 +248,13 @@ class Channel:
header = PacketHeader.get_error_header(self.get_channel_id_int(), length) header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data) return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
def write_and_encrypt( def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(ctrl_byte, payload)
)
def _write_and_encrypt(
self, payload: bytes, force: bool = False self, payload: bytes, force: bool = False
) -> Awaitable[None] | None: ) -> Awaitable[None] | None:
payload_length = len(payload) payload_length = len(payload)
@ -297,12 +279,6 @@ class Channel:
) )
return None return None
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(ctrl_byte, payload)
)
def _prepare_write(self) -> None: def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false # TODO add condition that disallows to write when can_send_message is false
ABP.set_sending_allowed(self.channel_cache, False) ABP.set_sending_allowed(self.channel_cache, False)
@ -330,6 +306,31 @@ class Channel:
loop.clear() loop.clear()
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("encrypt")
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
noise_buffer = memoryview(buffer)[0:noise_payload_len]
if utils.DISABLE_ENCRYPTION:
tag = crypto.DUMMY_TAG
else:
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
assert key_send is not None
assert nonce_send is not None
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("New nonce_send: ", str((nonce_send + 1)))
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
def _can_clear_loop(self) -> bool: def _can_clear_loop(self) -> bool:
return ( return (
not workflow.tasks not workflow.tasks
@ -341,7 +342,7 @@ class Channel:
log.debug( log.debug(
__name__, __name__,
"(cid: %s) %s%s", "(cid: %s) %s%s",
utils.get_bytes_as_str(self.channel_id), get_bytes_as_str(self.channel_id),
text_1, text_1,
text_2, text_2,
) )

View File

@ -60,9 +60,7 @@ if TYPE_CHECKING:
from .channel import Channel from .channel import Channel
if __debug__: if __debug__:
from ubinascii import hexlify from trezor.utils import get_bytes_as_str
from . import state_to_str
_TREZOR_STATE_UNPAIRED = b"\x00" _TREZOR_STATE_UNPAIRED = b"\x00"
@ -198,8 +196,6 @@ def _handle_message_to_app_or_channel(
ctrl_byte: int, ctrl_byte: int,
) -> Awaitable[None]: ) -> Awaitable[None]:
state = ctx.get_channel_state() state = ctx.get_channel_state()
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "state: %s", state_to_str(state))
if state is ChannelState.ENCRYPTED_TRANSPORT: if state is ChannelState.ENCRYPTED_TRANSPORT:
return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length) return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
@ -244,14 +240,14 @@ async def _handle_state_TH1(
log.debug( log.debug(
__name__, __name__,
"trezor ephemeral pubkey: %s", "trezor ephemeral pubkey: %s",
hexlify(trezor_ephemeral_pubkey).decode(), get_bytes_as_str(trezor_ephemeral_pubkey),
) )
log.debug( log.debug(
__name__, __name__,
"encrypted trezor masked static pubkey: %s", "encrypted trezor masked static pubkey: %s",
hexlify(encrypted_trezor_static_pubkey).decode(), get_bytes_as_str(encrypted_trezor_static_pubkey),
) )
log.debug(__name__, "tag: %s", hexlify(tag)) log.debug(__name__, "tag: %s", get_bytes_as_str(tag))
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag