mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-04 13:38:28 +00:00
refactor(core): improve readability and logging in channel.py
[no changelog]
This commit is contained in:
parent
c019e65631
commit
f6ea5ea630
@ -49,19 +49,27 @@ class Channel:
|
||||
def __init__(self, channel_cache: ChannelCache) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "channel initialization")
|
||||
|
||||
# Channel properties
|
||||
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
|
||||
self.channel_cache: ChannelCache = channel_cache
|
||||
self.is_cont_packet_expected: bool = False
|
||||
self.expected_payload_length: int = 0
|
||||
self.bytes_read: int = 0
|
||||
self.buffer: utils.BufferType
|
||||
self.channel_id: bytes = channel_cache.channel_id
|
||||
|
||||
# Shared variables
|
||||
self.buffer: utils.BufferType
|
||||
self.bytes_read: int = 0
|
||||
self.expected_payload_length: int = 0
|
||||
self.is_cont_packet_expected: bool = False
|
||||
self.selected_pairing_methods = []
|
||||
self.sessions: dict[int, GenericSessionContext] = {}
|
||||
self.write_task_spawn: loop.spawn | None = None
|
||||
self.connection_context: PairingContext | None = None
|
||||
|
||||
# Objects for writing a message to a wire
|
||||
self.transmission_loop: TransmissionLoop | None = None
|
||||
self.write_task_spawn: loop.spawn | None = None
|
||||
|
||||
# Temporary objects for handshake and pairing
|
||||
self.handshake: crypto.Handshake | None = None
|
||||
self.connection_context: PairingContext | None = None
|
||||
|
||||
def clear(self) -> None:
|
||||
clear_sessions_with_channel_id(self.channel_id)
|
||||
@ -74,12 +82,7 @@ class Channel:
|
||||
def get_channel_state(self) -> int:
|
||||
state = int.from_bytes(self.channel_cache.state, "big")
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) get_channel_state: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
state_to_str(state),
|
||||
)
|
||||
self._log("get_channel_state: ", state_to_str(state))
|
||||
return state
|
||||
|
||||
def get_handshake_hash(self) -> bytes:
|
||||
@ -90,42 +93,22 @@ class Channel:
|
||||
def set_channel_state(self, state: ChannelState) -> None:
|
||||
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) set_channel_state: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
state_to_str(state),
|
||||
)
|
||||
self._log("set_channel_state: ", state_to_str(state))
|
||||
|
||||
def set_buffer(self, buffer: utils.BufferType) -> None:
|
||||
self.buffer = buffer
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) set_buffer: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
type(self.buffer),
|
||||
)
|
||||
self._log("set_buffer: ", str(type(self.buffer)))
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
|
||||
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) receive_packet",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
|
||||
self._log("receive packet")
|
||||
self._handle_received_packet(packet)
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) self.buffer: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
utils.get_bytes_as_str(self.buffer),
|
||||
)
|
||||
self._log("self.buffer: ", utils.get_bytes_as_str(self.buffer))
|
||||
|
||||
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
||||
self._finish_message()
|
||||
@ -146,13 +129,10 @@ class Channel:
|
||||
|
||||
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) handle_init_packet",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
# ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) # TODO use this with single packet decryption
|
||||
_, _, payload_length = ustruct.unpack(">BHH", packet)
|
||||
self._log("handle_init_packet")
|
||||
|
||||
# ctrl_byte, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) # TODO use this with single packet decryption
|
||||
_, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet)
|
||||
self.expected_payload_length = payload_length
|
||||
packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
|
||||
|
||||
@ -173,27 +153,15 @@ class Channel:
|
||||
)
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) handle_init_packet - payload len: %d",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
payload_length,
|
||||
)
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) handle_init_packet - buffer len: %d",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
len(self.buffer),
|
||||
)
|
||||
self._log("handle_init_packet - payload len: ", str(payload_length))
|
||||
self._log("handle_init_packet - buffer len: ", str(len(self.buffer)))
|
||||
|
||||
return self._buffer_packet_data(self.buffer, packet, 0)
|
||||
|
||||
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) handle_cont_packet",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
self._log("handle_cont_packet")
|
||||
|
||||
if not self.is_cont_packet_expected:
|
||||
raise ThpError("Continuation packet is not expected, ignoring")
|
||||
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
||||
@ -224,54 +192,30 @@ class Channel:
|
||||
|
||||
assert key_receive is not None
|
||||
assert nonce_receive is not None
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) Buffer before decryption: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
hexlify(noise_buffer),
|
||||
)
|
||||
self._log("Buffer before decryption: ", hexlify(noise_buffer))
|
||||
|
||||
is_tag_valid = crypto.dec(
|
||||
noise_buffer, tag, key_receive, nonce_receive, b""
|
||||
)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) Buffer after decryption: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
hexlify(noise_buffer),
|
||||
)
|
||||
self._log("Buffer after decryption: ", hexlify(noise_buffer))
|
||||
|
||||
self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) Is decrypted tag valid? %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
str(is_tag_valid),
|
||||
)
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) Received tag: %s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
(hexlify(tag).decode()),
|
||||
)
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) New nonce_receive: %i",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
nonce_receive + 1,
|
||||
)
|
||||
self._log("Is decrypted tag valid? ", str(is_tag_valid))
|
||||
self._log("Received tag: ", hexlify(tag).decode())
|
||||
self._log("New nonce_receive: ", str((nonce_receive + 1)))
|
||||
|
||||
if not is_tag_valid:
|
||||
raise ThpDecryptionError()
|
||||
|
||||
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id)
|
||||
)
|
||||
self._log("encrypt")
|
||||
|
||||
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
||||
|
||||
noise_buffer = memoryview(buffer)[0:noise_payload_len]
|
||||
@ -289,7 +233,7 @@ class Channel:
|
||||
|
||||
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
|
||||
self._log("New nonce_send: ", str((nonce_send + 1)))
|
||||
|
||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||
|
||||
@ -312,13 +256,7 @@ class Channel:
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) write message: %s\n%s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
msg.MESSAGE_NAME,
|
||||
utils.dump_protobuf(msg),
|
||||
)
|
||||
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
|
||||
|
||||
self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
|
||||
noise_payload_len = memory_manager.encode_into_buffer(
|
||||
@ -347,9 +285,8 @@ class Channel:
|
||||
self._prepare_write()
|
||||
if force:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__, "Writing FORCE message (without async or retransmission)."
|
||||
)
|
||||
self._log("Writing FORCE message (without async or retransmission).")
|
||||
|
||||
return self._write_encrypted_payload_loop(
|
||||
ENCRYPTED, memoryview(self.buffer[:payload_length])
|
||||
)
|
||||
@ -374,11 +311,8 @@ class Channel:
|
||||
self, ctrl_byte: int, payload: bytes
|
||||
) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid %s) write_encrypted_payload_loop",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
self._log("write_encrypted_payload_loop")
|
||||
|
||||
payload_len = len(payload) + CHECKSUM_LENGTH
|
||||
sync_bit = ABP.get_send_seq_bit(self.channel_cache)
|
||||
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
|
||||
@ -392,14 +326,22 @@ class Channel:
|
||||
# workflow and the state is ENCRYPTED_TRANSPORT
|
||||
if self._can_clear_loop():
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) clearing loop from channel",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
)
|
||||
self._log("clearing loop from channel")
|
||||
|
||||
loop.clear()
|
||||
|
||||
def _can_clear_loop(self) -> bool:
|
||||
return (
|
||||
not workflow.tasks
|
||||
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
|
||||
|
||||
if __debug__:
|
||||
|
||||
def _log(self, text_1: str, text_2: str = "") -> None:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) %s%s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
text_1,
|
||||
text_2,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user