1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

refactor(core): improve readability and logging in channel.py

[no changelog]
This commit is contained in:
M1nd3r 2024-12-03 17:19:41 +01:00
parent daa05bc760
commit 5c7f5edb80

View File

@ -29,7 +29,7 @@ from .writer import (
)
if __debug__:
from ubinascii import hexlify
from trezor.utils import get_bytes_as_str
from . import state_to_str
@ -49,19 +49,27 @@ class Channel:
def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "channel initialization")
# Channel properties
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
self.channel_cache: ChannelCache = channel_cache
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
self.buffer: utils.BufferType
self.channel_id: bytes = channel_cache.channel_id
# Shared variables
self.buffer: utils.BufferType
self.bytes_read: int = 0
self.expected_payload_length: int = 0
self.is_cont_packet_expected: bool = False
self.selected_pairing_methods = []
self.sessions: dict[int, GenericSessionContext] = {}
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
# Objects for writing a message to a wire
self.transmission_loop: TransmissionLoop | None = None
self.write_task_spawn: loop.spawn | None = None
# Temporary objects for handshake and pairing
self.handshake: crypto.Handshake | None = None
self.connection_context: PairingContext | None = None
def clear(self) -> None:
clear_sessions_with_channel_id(self.channel_id)
@ -74,12 +82,7 @@ class Channel:
def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big")
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) get_channel_state: %s",
utils.get_bytes_as_str(self.channel_id),
state_to_str(state),
)
self._log("get_channel_state: ", state_to_str(state))
return state
def get_handshake_hash(self) -> bytes:
@ -90,42 +93,22 @@ class Channel:
def set_channel_state(self, state: ChannelState) -> None:
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) set_channel_state: %s",
utils.get_bytes_as_str(self.channel_id),
state_to_str(state),
)
self._log("set_channel_state: ", state_to_str(state))
def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) set_buffer: %s",
utils.get_bytes_as_str(self.channel_id),
type(self.buffer),
)
self._log("set_buffer: ", str(type(self.buffer)))
# CALLED BY THP_MAIN_LOOP
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) receive_packet",
utils.get_bytes_as_str(self.channel_id),
)
self._log("receive packet")
self._handle_received_packet(packet)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) self.buffer: %s",
utils.get_bytes_as_str(self.channel_id),
utils.get_bytes_as_str(self.buffer),
)
self._log("self.buffer: ", utils.get_bytes_as_str(self.buffer))
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
self._finish_message()
@ -146,13 +129,10 @@ class Channel:
def _handle_init_packet(self, packet: utils.BufferType) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) handle_init_packet",
utils.get_bytes_as_str(self.channel_id),
)
# ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption
_, _, payload_length = ustruct.unpack(">BHH", packet)
self._log("handle_init_packet")
# ctrl_byte, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) # TODO use this with single packet decryption
_, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet)
self.expected_payload_length = payload_length
packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
@ -173,27 +153,15 @@ class Channel:
)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) handle_init_packet - payload len: %d",
utils.get_bytes_as_str(self.channel_id),
payload_length,
)
log.debug(
__name__,
"(cid: %s) handle_init_packet - buffer len: %d",
utils.get_bytes_as_str(self.channel_id),
len(self.buffer),
)
self._log("handle_init_packet - payload len: ", str(payload_length))
self._log("handle_init_packet - buffer len: ", str(len(self.buffer)))
return self._buffer_packet_data(self.buffer, packet, 0)
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) handle_cont_packet",
utils.get_bytes_as_str(self.channel_id),
)
self._log("handle_cont_packet")
if not self.is_cont_packet_expected:
raise ThpError("Continuation packet is not expected, ignoring")
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
@ -224,54 +192,30 @@ class Channel:
assert key_receive is not None
assert nonce_receive is not None
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) Buffer before decryption: %s",
utils.get_bytes_as_str(self.channel_id),
hexlify(noise_buffer),
)
self._log("Buffer before decryption: ", get_bytes_as_str(noise_buffer))
is_tag_valid = crypto.dec(
noise_buffer, tag, key_receive, nonce_receive, b""
)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) Buffer after decryption: %s",
utils.get_bytes_as_str(self.channel_id),
hexlify(noise_buffer),
)
self._log("Buffer after decryption: ", get_bytes_as_str(noise_buffer))
self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) Is decrypted tag valid? %s",
utils.get_bytes_as_str(self.channel_id),
str(is_tag_valid),
)
log.debug(
__name__,
"(cid: %s) Received tag: %s",
utils.get_bytes_as_str(self.channel_id),
(hexlify(tag).decode()),
)
log.debug(
__name__,
"(cid: %s) New nonce_receive: %i",
utils.get_bytes_as_str(self.channel_id),
nonce_receive + 1,
)
self._log("Is decrypted tag valid? ", str(is_tag_valid))
self._log("Received tag: ", get_bytes_as_str(tag))
self._log("New nonce_receive: ", str((nonce_receive + 1)))
if not is_tag_valid:
raise ThpDecryptionError()
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id)
)
self._log("encrypt")
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
noise_buffer = memoryview(buffer)[0:noise_payload_len]
@ -289,7 +233,7 @@ class Channel:
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
self._log("New nonce_send: ", str((nonce_send + 1)))
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
@ -312,13 +256,7 @@ class Channel:
force: bool = False,
) -> None:
if __debug__ and utils.EMULATOR:
log.debug(
__name__,
"(cid: %s) write message: %s\n%s",
utils.get_bytes_as_str(self.channel_id),
msg.MESSAGE_NAME,
utils.dump_protobuf(msg),
)
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
noise_payload_len = memory_manager.encode_into_buffer(
@ -347,9 +285,8 @@ class Channel:
self._prepare_write()
if force:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__, "Writing FORCE message (without async or retransmission)."
)
self._log("Writing FORCE message (without async or retransmission).")
return self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length])
)
@ -374,11 +311,8 @@ class Channel:
self, ctrl_byte: int, payload: bytes
) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid %s) write_encrypted_payload_loop",
utils.get_bytes_as_str(self.channel_id),
)
self._log("write_encrypted_payload_loop")
payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = ABP.get_send_seq_bit(self.channel_cache)
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
@ -392,14 +326,22 @@ class Channel:
# workflow and the state is ENCRYPTED_TRANSPORT
if self._can_clear_loop():
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"(cid: %s) clearing loop from channel",
utils.get_bytes_as_str(self.channel_id),
)
self._log("clearing loop from channel")
loop.clear()
def _can_clear_loop(self) -> bool:
return (
not workflow.tasks
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
if __debug__:
def _log(self, text_1: str, text_2: str = "") -> None:
log.debug(
__name__,
"(cid: %s) %s%s",
utils.get_bytes_as_str(self.channel_id),
text_1,
text_2,
)