Improve logging, partially refactror channel.py

M1nd3r/thp2
M1nd3r 4 weeks ago
parent 46d46b1845
commit 9869b42ce5

@ -68,7 +68,6 @@ class Channel(Context):
from trezor.wire.thp.session_context import load_cached_sessions
self.connection_context = None
self.sessions = load_cached_sessions(self)
@classmethod
@ -107,13 +106,12 @@ class Channel(Context):
async def receive_packet(self, packet: utils.BufferType):
if __debug__:
log.debug(__name__, "receive_packet")
ctrl_byte = packet[0]
if _is_ctrl_byte_continuation(ctrl_byte):
await self._handle_cont_packet(packet)
else:
await self._handle_init_packet(packet)
await self._handle_received_packet(packet)
if __debug__:
log.debug(__name__, "self.buffer: %s", utils.get_bytes_as_str(self.buffer))
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
self._finish_message()
await self._handle_completed_message()
@ -124,7 +122,14 @@ class Channel(Context):
"Read more bytes than is the expected length of the message, this should not happen!"
)
async def _handle_init_packet(self, packet: utils.BufferType):
async def _handle_received_packet(self, packet: utils.BufferType) -> None:
ctrl_byte = packet[0]
if _is_ctrl_byte_continuation(ctrl_byte):
await self._handle_cont_packet(packet)
else:
await self._handle_init_packet(packet)
async def _handle_init_packet(self, packet: utils.BufferType) -> None:
if __debug__:
log.debug(__name__, "handle_init_packet")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
@ -135,6 +140,16 @@ class Channel(Context):
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
packet_payload = self._decrypt_single_packet_payload(packet_payload)
self._select_buffer(packet_payload, payload_length)
await self._buffer_packet_data(self.buffer, packet, 0)
if __debug__:
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
def _select_buffer(
self, packet_payload: utils.BufferType, payload_length: int
) -> None:
state = self.get_channel_state()
if state is ChannelState.ENCRYPTED_TRANSPORT:
@ -157,16 +172,8 @@ class Channel(Context):
except Exception as e:
if __debug__:
log.exception(__name__, e)
if __debug__:
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
if __debug__:
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
await self._buffer_packet_data(self.buffer, packet, 0)
if __debug__:
log.debug(__name__, "handle_init_packet - end")
async def _handle_cont_packet(self, packet: utils.BufferType):
async def _handle_cont_packet(self, packet: utils.BufferType) -> None:
if __debug__:
log.debug(__name__, "handle_cont_packet")
if not self.is_cont_packet_expected:
@ -257,9 +264,11 @@ class Channel(Context):
if state is ChannelState.TH2:
await self._handle_state_TH2(message_length, ctrl_byte, sync_bit)
return
if is_channel_state_pairing(state):
await self._handle_pairing(message_length)
return
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
@ -314,7 +323,7 @@ class Channel(Context):
"ThpHandshakeCompletionReqNoisePayload",
)
if TYPE_CHECKING:
assert isinstance(noise_payload, ThpHandshakeCompletionReqNoisePayload)
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
for i in noise_payload.pairing_methods:
self.selected_pairing_methods.append(i)
if __debug__:
@ -325,6 +334,7 @@ class Channel(Context):
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
# TODO add credential recognition
paired: bool = False # TODO should be output from credential check
# send hanshake completion response
@ -334,7 +344,6 @@ class Channel(Context):
thp_messages.get_handshake_completion_response(paired=paired),
)
)
# TODO add credential recognition
if paired:
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
@ -343,6 +352,7 @@ class Channel(Context):
async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
if __debug__:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
self._decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
@ -434,6 +444,8 @@ class Channel(Context):
response_message = await task
# TODO handle
await self.write(response_message)
if __debug__:
log.debug(__name__, "_handle_channel_message - end")
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
payload_buffer = bytearray(payload)
@ -536,6 +548,8 @@ class Channel(Context):
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
)
)
if __debug__:
log.debug(__name__, "Scheduled _write_encrypted_payload_loop")
async def _write_encrypted_payload_loop(
self, ctrl_byte: int, payload: bytes
@ -574,6 +588,8 @@ class Channel(Context):
not workflow.tasks
and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
):
if __debug__:
log.debug(__name__, "Clearing loop from channel")
loop.clear()
async def _wait_for_ack(self) -> None:
@ -653,10 +669,11 @@ def _get_buffer_for_message(
) -> utils.BufferType:
length = payload_length + INIT_DATA_OFFSET
if __debug__:
log.debug(__name__, "get_buffer_for_message - length: %d", length)
log.debug(
__name__,
"get_buffer_for_message - existing buffer type: %s",
"get_buffer_for_message - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:

@ -51,7 +51,6 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
# following bytes are not "##"", do not respond
if cid == BROADCAST_CHANNEL_ID:
# TODO handle exceptions, try-catch?
await _handle_broadcast(iface, ctrl_byte, packet)
continue

Loading…
Cancel
Save