mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-28 15:22:14 +00:00
feat(core): add a flag to disable most of debug logging
[no changelog]
This commit is contained in:
parent
c241adfc4d
commit
38fbfc7a61
@ -37,6 +37,8 @@ DISABLE_ANIMATION = 0
|
|||||||
|
|
||||||
DISABLE_ENCRYPTION: bool = False
|
DISABLE_ENCRYPTION: bool = False
|
||||||
|
|
||||||
|
ALLOW_DEBUG_MESSAGES: bool = False
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
if EMULATOR:
|
if EMULATOR:
|
||||||
import uos
|
import uos
|
||||||
|
@ -30,7 +30,7 @@ def wrap_protobuf_load(
|
|||||||
expected_type: type[LoadedMessageType],
|
expected_type: type[LoadedMessageType],
|
||||||
) -> LoadedMessageType:
|
) -> LoadedMessageType:
|
||||||
try:
|
try:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"Buffer to be parsed to a LoadedMessage: %s",
|
"Buffer to be parsed to a LoadedMessage: %s",
|
||||||
@ -43,7 +43,7 @@ def wrap_protobuf_load(
|
|||||||
)
|
)
|
||||||
return msg
|
return msg
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.exception(__name__, e)
|
log.exception(__name__, e)
|
||||||
if e.args:
|
if e.args:
|
||||||
raise DataError("Failed to decode message: " + " ".join(e.args))
|
raise DataError("Failed to decode message: " + " ".join(e.args))
|
||||||
@ -94,7 +94,7 @@ async def handle_single_message(
|
|||||||
the type of message is supposed to be optimized and not disrupt the running state,
|
the type of message is supposed to be optimized and not disrupt the running state,
|
||||||
this function will return `True`.
|
this function will return `True`.
|
||||||
"""
|
"""
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
try:
|
try:
|
||||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -190,7 +190,7 @@ async def handle_single_message(
|
|||||||
# - the message was not valid protobuf
|
# - the message was not valid protobuf
|
||||||
# - workflow raised some kind of an exception while running
|
# - workflow raised some kind of an exception while running
|
||||||
# - something canceled the workflow from the outside
|
# - something canceled the workflow from the outside
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
if isinstance(exc, ActionCancelled):
|
if isinstance(exc, ActionCancelled):
|
||||||
log.debug(__name__, "cancelled: %s", exc.message)
|
log.debug(__name__, "cancelled: %s", exc.message)
|
||||||
elif isinstance(exc, loop.TaskClosed):
|
elif isinstance(exc, loop.TaskClosed):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from storage.cache_thp import ChannelCache
|
from storage.cache_thp import ChannelCache
|
||||||
from trezor import log
|
from trezor import log, utils
|
||||||
from trezor.wire.thp import ThpError
|
from trezor.wire.thp import ThpError
|
||||||
|
|
||||||
|
|
||||||
@ -20,14 +20,14 @@ def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
|
|||||||
|
|
||||||
def _is_ack_expected(cache: ChannelCache) -> bool:
|
def _is_ack_expected(cache: ChannelCache) -> bool:
|
||||||
is_expected: bool = not is_sending_allowed(cache)
|
is_expected: bool = not is_sending_allowed(cache)
|
||||||
if __debug__ and not is_expected:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_expected:
|
||||||
log.debug(__name__, "Received unexpected ACK message")
|
log.debug(__name__, "Received unexpected ACK message")
|
||||||
return is_expected
|
return is_expected
|
||||||
|
|
||||||
|
|
||||||
def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
|
def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
|
||||||
is_correct: bool = get_send_seq_bit(cache) == sync_bit
|
is_correct: bool = get_send_seq_bit(cache) == sync_bit
|
||||||
if __debug__ and not is_correct:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES and not is_correct:
|
||||||
log.debug(__name__, "Received ACK message with wrong ack bit")
|
log.debug(__name__, "Received ACK message with wrong ack bit")
|
||||||
return is_correct
|
return is_correct
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ class Channel:
|
|||||||
|
|
||||||
def get_channel_state(self) -> int:
|
def get_channel_state(self) -> int:
|
||||||
state = int.from_bytes(self.channel_cache.state, "big")
|
state = int.from_bytes(self.channel_cache.state, "big")
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) get_channel_state: %s",
|
"(cid: %s) get_channel_state: %s",
|
||||||
@ -86,7 +86,7 @@ class Channel:
|
|||||||
|
|
||||||
def set_channel_state(self, state: ChannelState) -> None:
|
def set_channel_state(self, state: ChannelState) -> None:
|
||||||
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
|
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) set_channel_state: %s",
|
"(cid: %s) set_channel_state: %s",
|
||||||
@ -96,7 +96,7 @@ class Channel:
|
|||||||
|
|
||||||
def set_buffer(self, buffer: utils.BufferType) -> None:
|
def set_buffer(self, buffer: utils.BufferType) -> None:
|
||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) set_buffer: %s",
|
"(cid: %s) set_buffer: %s",
|
||||||
@ -107,7 +107,7 @@ class Channel:
|
|||||||
# CALLED BY THP_MAIN_LOOP
|
# CALLED BY THP_MAIN_LOOP
|
||||||
|
|
||||||
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) receive_packet",
|
"(cid: %s) receive_packet",
|
||||||
@ -116,7 +116,7 @@ class Channel:
|
|||||||
|
|
||||||
self._handle_received_packet(packet)
|
self._handle_received_packet(packet)
|
||||||
|
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) self.buffer: %s",
|
"(cid: %s) self.buffer: %s",
|
||||||
@ -142,7 +142,7 @@ class Channel:
|
|||||||
return self._handle_init_packet(packet)
|
return self._handle_init_packet(packet)
|
||||||
|
|
||||||
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) handle_init_packet",
|
"(cid: %s) handle_init_packet",
|
||||||
@ -169,7 +169,7 @@ class Channel:
|
|||||||
payload_length,
|
payload_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) handle_init_packet - payload len: %d",
|
"(cid: %s) handle_init_packet - payload len: %d",
|
||||||
@ -185,7 +185,7 @@ class Channel:
|
|||||||
return self._buffer_packet_data(self.buffer, packet, 0)
|
return self._buffer_packet_data(self.buffer, packet, 0)
|
||||||
|
|
||||||
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) handle_cont_packet",
|
"(cid: %s) handle_cont_packet",
|
||||||
@ -221,7 +221,7 @@ class Channel:
|
|||||||
|
|
||||||
assert key_receive is not None
|
assert key_receive is not None
|
||||||
assert nonce_receive is not None
|
assert nonce_receive is not None
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) Buffer before decryption: %s",
|
"(cid: %s) Buffer before decryption: %s",
|
||||||
@ -231,7 +231,7 @@ class Channel:
|
|||||||
is_tag_valid = crypto.dec(
|
is_tag_valid = crypto.dec(
|
||||||
noise_buffer, tag, key_receive, nonce_receive, b""
|
noise_buffer, tag, key_receive, nonce_receive, b""
|
||||||
)
|
)
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) Buffer after decryption: %s",
|
"(cid: %s) Buffer after decryption: %s",
|
||||||
@ -241,7 +241,7 @@ class Channel:
|
|||||||
|
|
||||||
self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
|
self.channel_cache.set_int(CHANNEL_NONCE_RECEIVE, nonce_receive + 1)
|
||||||
|
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) Is decrypted tag valid? %s",
|
"(cid: %s) Is decrypted tag valid? %s",
|
||||||
@ -265,7 +265,7 @@ class Channel:
|
|||||||
raise ThpDecryptionError()
|
raise ThpDecryptionError()
|
||||||
|
|
||||||
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id)
|
__name__, "(cid: %s) encrypt", utils.get_bytes_as_str(self.channel_id)
|
||||||
)
|
)
|
||||||
@ -285,7 +285,7 @@ class Channel:
|
|||||||
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
|
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
|
||||||
|
|
||||||
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
|
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
|
log.debug(__name__, "New nonce_send: %i", nonce_send + 1)
|
||||||
|
|
||||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||||
@ -343,7 +343,7 @@ class Channel:
|
|||||||
print("\nCLOSED\n")
|
print("\nCLOSED\n")
|
||||||
self._prepare_write()
|
self._prepare_write()
|
||||||
if force:
|
if force:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "Writing FORCE message (without async or retransmission)."
|
__name__, "Writing FORCE message (without async or retransmission)."
|
||||||
)
|
)
|
||||||
@ -370,7 +370,7 @@ class Channel:
|
|||||||
async def _write_encrypted_payload_loop(
|
async def _write_encrypted_payload_loop(
|
||||||
self, ctrl_byte: int, payload: bytes
|
self, ctrl_byte: int, payload: bytes
|
||||||
) -> None:
|
) -> None:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid %s) write_encrypted_payload_loop",
|
"(cid %s) write_encrypted_payload_loop",
|
||||||
@ -388,7 +388,7 @@ class Channel:
|
|||||||
# Let the main loop be restarted and clear loop, if there is no other
|
# Let the main loop be restarted and clear loop, if there is no other
|
||||||
# workflow and the state is ENCRYPTED_TRANSPORT
|
# workflow and the state is ENCRYPTED_TRANSPORT
|
||||||
if self._can_clear_loop():
|
if self._can_clear_loop():
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) clearing loop from channel",
|
"(cid: %s) clearing loop from channel",
|
||||||
|
@ -36,7 +36,7 @@ def select_buffer(
|
|||||||
buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer)
|
buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer)
|
||||||
return buffer
|
return buffer
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.exception(__name__, e)
|
log.exception(__name__, e)
|
||||||
raise Exception("Failed to create a buffer for channel") # TODO handle better
|
raise Exception("Failed to create a buffer for channel") # TODO handle better
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ def _get_buffer_for_read(
|
|||||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||||
) -> utils.BufferType:
|
) -> utils.BufferType:
|
||||||
length = payload_length + INIT_HEADER_LENGTH
|
length = payload_length + INIT_HEADER_LENGTH
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"get_buffer_for_read - length: %d, %s %s",
|
"get_buffer_for_read - length: %d, %s %s",
|
||||||
@ -110,13 +110,13 @@ def _get_buffer_for_read(
|
|||||||
raise ThpError("Message too large")
|
raise ThpError("Message too large")
|
||||||
|
|
||||||
if length > len(existing_buffer):
|
if length > len(existing_buffer):
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Allocating a new buffer")
|
log.debug(__name__, "Allocating a new buffer")
|
||||||
|
|
||||||
from ..thp_main import get_raw_read_buffer
|
from ..thp_main import get_raw_read_buffer
|
||||||
|
|
||||||
if length > len(get_raw_read_buffer()):
|
if length > len(get_raw_read_buffer()):
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"Required length is %d, where raw buffer has capacity only %d",
|
"Required length is %d, where raw buffer has capacity only %d",
|
||||||
@ -133,7 +133,7 @@ def _get_buffer_for_read(
|
|||||||
return payload
|
return payload
|
||||||
|
|
||||||
# reuse a part of the supplied buffer
|
# reuse a part of the supplied buffer
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Reusing already allocated buffer")
|
log.debug(__name__, "Reusing already allocated buffer")
|
||||||
return memoryview(existing_buffer)[:length]
|
return memoryview(existing_buffer)[:length]
|
||||||
|
|
||||||
@ -142,7 +142,7 @@ def _get_buffer_for_write(
|
|||||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||||
) -> utils.BufferType:
|
) -> utils.BufferType:
|
||||||
length = payload_length + INIT_HEADER_LENGTH
|
length = payload_length + INIT_HEADER_LENGTH
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"get_buffer_for_write - length: %d, %s %s",
|
"get_buffer_for_write - length: %d, %s %s",
|
||||||
@ -154,7 +154,7 @@ def _get_buffer_for_write(
|
|||||||
raise ThpError("Message too large")
|
raise ThpError("Message too large")
|
||||||
|
|
||||||
if length > len(existing_buffer):
|
if length > len(existing_buffer):
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Creating a new write buffer from raw write buffer")
|
log.debug(__name__, "Creating a new write buffer from raw write buffer")
|
||||||
|
|
||||||
from ..thp_main import get_raw_write_buffer
|
from ..thp_main import get_raw_write_buffer
|
||||||
@ -170,6 +170,6 @@ def _get_buffer_for_write(
|
|||||||
return payload
|
return payload
|
||||||
|
|
||||||
# reuse a part of the supplied buffer
|
# reuse a part of the supplied buffer
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Reusing already allocated buffer")
|
log.debug(__name__, "Reusing already allocated buffer")
|
||||||
return memoryview(existing_buffer)[:length]
|
return memoryview(existing_buffer)[:length]
|
||||||
|
@ -66,18 +66,19 @@ async def handle_received_message(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Handle a message received from the channel."""
|
"""Handle a message received from the channel."""
|
||||||
|
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "handle_received_message")
|
log.debug(__name__, "handle_received_message")
|
||||||
|
if utils.ALLOW_DEBUG_MESSAGES: # TODO remove after performance tests are done
|
||||||
try:
|
try:
|
||||||
import micropython
|
import micropython
|
||||||
|
|
||||||
print("micropython.mem_info() from received_message_handler.py")
|
print("micropython.mem_info() from received_message_handler.py")
|
||||||
micropython.mem_info()
|
micropython.mem_info()
|
||||||
print(
|
print("Allocation count:", micropython.alloc_count())
|
||||||
"Allocation count:", micropython.alloc_count() # type: ignore ["alloc_count" is not a known attribute of module "micropython"]
|
|
||||||
)
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
print("To show allocation count, create the build with TREZOR_MEMPERF=1")
|
print(
|
||||||
|
"To show allocation count, create the build with TREZOR_MEMPERF=1"
|
||||||
|
)
|
||||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
|
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
|
||||||
message_length = payload_length + INIT_HEADER_LENGTH
|
message_length = payload_length + INIT_HEADER_LENGTH
|
||||||
|
|
||||||
@ -86,7 +87,7 @@ async def handle_received_message(
|
|||||||
# Synchronization process
|
# Synchronization process
|
||||||
seq_bit = (ctrl_byte & 0x10) >> 4
|
seq_bit = (ctrl_byte & 0x10) >> 4
|
||||||
ack_bit = (ctrl_byte & 0x08) >> 3
|
ack_bit = (ctrl_byte & 0x08) >> 3
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"handle_completed_message - seq bit of message: %d, ack bit of message: %d",
|
"handle_completed_message - seq bit of message: %d, ack bit of message: %d",
|
||||||
@ -108,7 +109,7 @@ async def handle_received_message(
|
|||||||
|
|
||||||
# 2: Handle message with unexpected sequential bit
|
# 2: Handle message with unexpected sequential bit
|
||||||
if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache):
|
if seq_bit != ABP.get_expected_receive_seq_bit(ctx.channel_cache):
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Received message with an unexpected sequential bit")
|
log.debug(__name__, "Received message with an unexpected sequential bit")
|
||||||
await _send_ack(ctx, ack_bit=seq_bit)
|
await _send_ack(ctx, ack_bit=seq_bit)
|
||||||
raise ThpError("Received message with an unexpected sequential bit")
|
raise ThpError("Received message with an unexpected sequential bit")
|
||||||
@ -131,14 +132,14 @@ async def handle_received_message(
|
|||||||
except ThpInvalidDataError:
|
except ThpInvalidDataError:
|
||||||
await ctx.write_error(ThpErrorType.INVALID_DATA)
|
await ctx.write_error(ThpErrorType.INVALID_DATA)
|
||||||
ctx.clear()
|
ctx.clear()
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "handle_received_message - end")
|
log.debug(__name__, "handle_received_message - end")
|
||||||
|
|
||||||
|
|
||||||
def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]:
|
def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]:
|
||||||
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
||||||
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
|
header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"Writing ACK message to a channel with id: %d, ack_bit: %d",
|
"Writing ACK message to a channel with id: %d, ack_bit: %d",
|
||||||
@ -149,13 +150,13 @@ def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]:
|
|||||||
|
|
||||||
|
|
||||||
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
|
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "check_checksum")
|
log.debug(__name__, "check_checksum")
|
||||||
if not checksum.is_valid(
|
if not checksum.is_valid(
|
||||||
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
|
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
|
||||||
data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH],
|
data=memoryview(message_buffer)[: message_length - CHECKSUM_LENGTH],
|
||||||
):
|
):
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Invalid checksum, ignoring message.")
|
log.debug(__name__, "Invalid checksum, ignoring message.")
|
||||||
raise ThpError("Invalid checksum, ignoring message.")
|
raise ThpError("Invalid checksum, ignoring message.")
|
||||||
|
|
||||||
@ -164,17 +165,17 @@ async def _handle_ack(ctx: Channel, ack_bit: int):
|
|||||||
if not ABP.is_ack_valid(ctx.channel_cache, ack_bit):
|
if not ABP.is_ack_valid(ctx.channel_cache, ack_bit):
|
||||||
return
|
return
|
||||||
# ACK is expected and it has correct sync bit
|
# ACK is expected and it has correct sync bit
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Received ACK message with correct ack bit")
|
log.debug(__name__, "Received ACK message with correct ack bit")
|
||||||
if ctx.transmission_loop is not None:
|
if ctx.transmission_loop is not None:
|
||||||
ctx.transmission_loop.stop_immediately()
|
ctx.transmission_loop.stop_immediately()
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Stopped transmission loop")
|
log.debug(__name__, "Stopped transmission loop")
|
||||||
|
|
||||||
ABP.set_sending_allowed(ctx.channel_cache, True)
|
ABP.set_sending_allowed(ctx.channel_cache, True)
|
||||||
|
|
||||||
if ctx.write_task_spawn is not None:
|
if ctx.write_task_spawn is not None:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
|
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
|
||||||
await ctx.write_task_spawn
|
await ctx.write_task_spawn
|
||||||
# Note that no the write_task_spawn could result in loop.clear(),
|
# Note that no the write_task_spawn could result in loop.clear(),
|
||||||
@ -189,7 +190,7 @@ def _handle_message_to_app_or_channel(
|
|||||||
ctrl_byte: int,
|
ctrl_byte: int,
|
||||||
) -> Awaitable[None]:
|
) -> Awaitable[None]:
|
||||||
state = ctx.get_channel_state()
|
state = ctx.get_channel_state()
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "state: %s", state_to_str(state))
|
log.debug(__name__, "state: %s", state_to_str(state))
|
||||||
|
|
||||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||||
@ -213,7 +214,7 @@ async def _handle_state_TH1(
|
|||||||
message_length: int,
|
message_length: int,
|
||||||
ctrl_byte: int,
|
ctrl_byte: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "handle_state_TH1")
|
log.debug(__name__, "handle_state_TH1")
|
||||||
if not control_byte.is_handshake_init_req(ctrl_byte):
|
if not control_byte.is_handshake_init_req(ctrl_byte):
|
||||||
raise ThpError("Message received is not a handshake init request!")
|
raise ThpError("Message received is not a handshake init request!")
|
||||||
@ -231,7 +232,7 @@ async def _handle_state_TH1(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"trezor ephemeral pubkey: %s",
|
"trezor ephemeral pubkey: %s",
|
||||||
@ -255,7 +256,7 @@ async def _handle_state_TH1(
|
|||||||
async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None:
|
async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None:
|
||||||
from apps.thp.credential_manager import validate_credential
|
from apps.thp.credential_manager import validate_credential
|
||||||
|
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "handle_state_TH2")
|
log.debug(__name__, "handle_state_TH2")
|
||||||
if not control_byte.is_handshake_comp_req(ctrl_byte):
|
if not control_byte.is_handshake_comp_req(ctrl_byte):
|
||||||
raise ThpError("Message received is not a handshake completion request!")
|
raise ThpError("Message received is not a handshake completion request!")
|
||||||
@ -298,7 +299,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
|||||||
raise ThpInvalidDataError()
|
raise ThpInvalidDataError()
|
||||||
if method not in ctx.selected_pairing_methods:
|
if method not in ctx.selected_pairing_methods:
|
||||||
ctx.selected_pairing_methods.append(method)
|
ctx.selected_pairing_methods.append(method)
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"host static pubkey: %s, noise payload: %s",
|
"host static pubkey: %s, noise payload: %s",
|
||||||
@ -318,7 +319,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
|||||||
host_static_pubkey,
|
host_static_pubkey,
|
||||||
)
|
)
|
||||||
except DataError as e:
|
except DataError as e:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.exception(__name__, e)
|
log.exception(__name__, e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -340,7 +341,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None:
|
async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -> None:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
|
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
|
||||||
|
|
||||||
ctx.decrypt_buffer(message_length)
|
ctx.decrypt_buffer(message_length)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache
|
from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache
|
||||||
from trezor import log, loop, protobuf
|
from trezor import log, loop, protobuf, utils
|
||||||
from trezor.wire import message_handler, protocol_common
|
from trezor.wire import message_handler, protocol_common
|
||||||
from trezor.wire.context import UnexpectedMessageException
|
from trezor.wire.context import UnexpectedMessageException
|
||||||
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure, find_handler
|
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure, find_handler
|
||||||
@ -76,7 +76,7 @@ class GenericSessionContext(Context):
|
|||||||
try:
|
try:
|
||||||
message = await self._get_message(self.incoming_message, next_message)
|
message = await self._get_message(self.incoming_message, next_message)
|
||||||
except protocol_common.WireError as e:
|
except protocol_common.WireError as e:
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.exception(__name__, e)
|
log.exception(__name__, e)
|
||||||
await self.write(failure(e))
|
await self.write(failure(e))
|
||||||
return _REPEAT_LOOP
|
return _REPEAT_LOOP
|
||||||
|
@ -69,7 +69,7 @@ async def write_payloads_to_wire(
|
|||||||
packet_offset = CONT_HEADER_LENGTH
|
packet_offset = CONT_HEADER_LENGTH
|
||||||
|
|
||||||
# write packet to wire (in-lined)
|
# write packet to wire (in-lined)
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
||||||
)
|
)
|
||||||
@ -82,7 +82,7 @@ async def write_payloads_to_wire(
|
|||||||
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
|
async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None:
|
||||||
while True:
|
while True:
|
||||||
await loop.wait(iface.iface_num() | io.POLL_WRITE)
|
await loop.wait(iface.iface_num() | io.POLL_WRITE)
|
||||||
if __debug__:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
__name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet)
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user