Unallocated session error, debug log improvements

M1nd3r/thp5
M1nd3r 2 months ago
parent 893f606535
commit 4e163a2614

@ -7,8 +7,10 @@ import usb
from storage import cache_thp from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from trezor import io, log, loop, protobuf, utils from trezor import io, log, loop, protobuf, utils
from trezor.messages import ThpCreateNewSession from trezor.enums import FailureType, MessageType
from trezor.messages import Failure, ThpCreateNewSession
from trezor.wire import message_handler from trezor.wire import message_handler
from trezor.wire.errors import Error
from trezor.wire.thp import thp_messages from trezor.wire.thp import thp_messages
from ..protocol_common import Context, MessageWithType from ..protocol_common import Context, MessageWithType
@ -19,6 +21,7 @@ from .thp_messages import (
ACK_MESSAGE, ACK_MESSAGE,
CONTINUATION_PACKET, CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT, ENCRYPTED_TRANSPORT,
ERROR,
HANDSHAKE_INIT, HANDSHAKE_INIT,
InitHeader, InitHeader,
) )
@ -72,27 +75,35 @@ class Channel(Context):
def get_channel_state(self) -> int: def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big") state = int.from_bytes(self.channel_cache.state, "big")
print("channel.get_ch_state:", _state_to_str(state)) if __debug__:
log.debug(__name__, "get_channel_state: %s", _state_to_str(state))
return state return state
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def set_channel_state(self, state: ChannelState) -> None: def set_channel_state(self, state: ChannelState) -> None:
print("channel.set_ch_state:", _state_to_str(state)) if __debug__:
log.debug(__name__, "set_channel_state: %s", _state_to_str(state))
self.channel_cache.state = bytearray(state.to_bytes(1, "big")) self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
def set_buffer(self, buffer: utils.BufferType) -> None: def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer self.buffer = buffer
print("channel.set_buffer:", type(self.buffer)) if __debug__:
log.debug(__name__, "set_buffer: %s", type(self.buffer))
# CALLED BY THP_MAIN_LOOP # CALLED BY THP_MAIN_LOOP
async def receive_packet(self, packet: utils.BufferType): async def receive_packet(self, packet: utils.BufferType):
print("channel.receive_packet") if __debug__:
log.debug(__name__, "receive_packet")
ctrl_byte = packet[0] ctrl_byte = packet[0]
if _is_ctrl_byte_continuation(ctrl_byte): if _is_ctrl_byte_continuation(ctrl_byte):
await self._handle_cont_packet(packet) await self._handle_cont_packet(packet)
else: else:
await self._handle_init_packet(packet) await self._handle_init_packet(packet)
printBytes(self.buffer) if __debug__:
log.debug(__name__, "self.buffer: %s", get_bytes_as_str(self.buffer))
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
self._finish_message() self._finish_message()
await self._handle_completed_message() await self._handle_completed_message()
@ -104,7 +115,8 @@ class Channel(Context):
) )
async def _handle_init_packet(self, packet: utils.BufferType): async def _handle_init_packet(self, packet: utils.BufferType):
print("channel._handle_init_packet") if __debug__:
log.debug(__name__, "handle_init_packet")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
self.expected_payload_length = payload_length self.expected_payload_length = payload_length
packet_payload = packet[5:] packet_payload = packet[5:]
@ -133,20 +145,27 @@ class Channel(Context):
payload_length, self.buffer payload_length, self.buffer
) )
except Exception as e: except Exception as e:
print(e) if __debug__:
print("channel._handle_init_packet - payload len", payload_length) log.exception(__name__, e)
print("channel._handle_init_packet - buffer len", len(self.buffer)) if __debug__:
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
if __debug__:
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
await self._buffer_packet_data(self.buffer, packet, 0) await self._buffer_packet_data(self.buffer, packet, 0)
print("channel._handle_init_packet - end") if __debug__:
log.debug(__name__, "channel._handle_init_packet - end")
async def _handle_cont_packet(self, packet: utils.BufferType): async def _handle_cont_packet(self, packet: utils.BufferType):
print("channel._handle_cont_packet") if __debug__:
log.debug(__name__, "handle_cont_packet")
if not self.is_cont_packet_expected: if not self.is_cont_packet_expected:
raise ThpError("Continuation packet is not expected, ignoring") raise ThpError("Continuation packet is not expected, ignoring")
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET) await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
async def _handle_completed_message(self) -> None: async def _handle_completed_message(self) -> None:
print("channel._handle_completed_message") if __debug__:
log.debug(__name__, "handle_completed_message")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
message_length = payload_length + INIT_DATA_OFFSET message_length = payload_length + INIT_DATA_OFFSET
@ -154,7 +173,12 @@ class Channel(Context):
# Synchronization process # Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4 sync_bit = (ctrl_byte & 0x10) >> 4
print("channel._handle_completed_message - sync bit of message:", sync_bit) if __debug__:
log.debug(
__name__,
"handle_completed_message - sync bit of message: %d",
sync_bit,
)
# 1: Handle ACKs # 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte): if _is_ctrl_byte_ack(ctrl_byte):
@ -179,15 +203,19 @@ class Channel(Context):
await self._handle_message_to_app_or_channel( await self._handle_message_to_app_or_channel(
payload_length, message_length, ctrl_byte, sync_bit payload_length, message_length, ctrl_byte, sync_bit
) )
print("channel._handle_completed_message - end") if __debug__:
log.debug(__name__, "handle_completed_message - end")
def _check_checksum(self, message_length: int): def _check_checksum(self, message_length: int):
print("channel._check_checksum") if __debug__:
log.debug(__name__, "check_checksum")
if not checksum.is_valid( if not checksum.is_valid(
checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length], checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length],
data=self.buffer[: message_length - CHECKSUM_LENGTH], data=self.buffer[: message_length - CHECKSUM_LENGTH],
): ):
self._todo_clear_buffer() self._todo_clear_buffer()
if __debug__:
log.debug(__name__, "Invalid checksum, ignoring message.")
raise ThpError("Invalid checksum, ignoring message.") raise ThpError("Invalid checksum, ignoring message.")
async def _handle_message_to_app_or_channel( async def _handle_message_to_app_or_channel(
@ -195,7 +223,7 @@ class Channel(Context):
) -> None: ) -> None:
state = self.get_channel_state() state = self.get_channel_state()
if __debug__: if __debug__:
log.debug(__name__, "state: " + _state_to_str(state)) log.debug(__name__, "state: %s", _state_to_str(state))
if state is ChannelState.TH1: if state is ChannelState.TH1:
await self._handle_state_TH1(payload_length, message_length, sync_bit) await self._handle_state_TH1(payload_length, message_length, sync_bit)
@ -206,7 +234,7 @@ class Channel(Context):
raise ThpError("Message is not encrypted. Ignoring") raise ThpError("Message is not encrypted. Ignoring")
if state is ChannelState.ENCRYPTED_TRANSPORT: if state is ChannelState.ENCRYPTED_TRANSPORT:
self._handle_state_ENCRYPTED_TRANSPORT(message_length) await self._handle_state_ENCRYPTED_TRANSPORT(message_length)
return return
if state is ChannelState.TH2: if state is ChannelState.TH2:
@ -239,7 +267,8 @@ class Channel(Context):
return return
async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None: async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None:
print("channel._handle_state_TH2") if __debug__:
log.debug(__name__, "handle_state_TH2")
host_encrypted_static_pubkey = self.buffer[ host_encrypted_static_pubkey = self.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
] ]
@ -249,10 +278,13 @@ class Channel(Context):
+ TAG_LENGTH : message_length + TAG_LENGTH : message_length
- CHECKSUM_LENGTH - CHECKSUM_LENGTH
] ]
print( if __debug__:
host_encrypted_static_pubkey, log.debug(
handshake_completion_request_noise_payload, __name__,
) # TODO remove "host static pubkey: %s, noise payload: %s",
get_bytes_as_str(host_encrypted_static_pubkey),
get_bytes_as_str(handshake_completion_request_noise_payload),
)
# send hanshake completion response # send hanshake completion response
loop.schedule( loop.schedule(
@ -262,8 +294,9 @@ class Channel(Context):
) )
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None: async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
print("channel._handle_state_ENCRYPTED_TRANSPORT") if __debug__:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
self._decrypt_buffer(message_length) self._decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:]) session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
if session_id == 0: if session_id == 0:
@ -271,10 +304,16 @@ class Channel(Context):
return return
if session_id not in self.sessions: if session_id not in self.sessions:
await self.write_error(
FailureType.ThpUnallocatedSession, "Unallocated session"
)
raise ThpError("Unalloacted session") raise ThpError("Unalloacted session")
session_state = self.sessions[session_id].get_session_state() session_state = self.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED: if session_state is SessionState.UNALLOCATED:
await self.write_error(
FailureType.ThpUnallocatedSession, "Unallocated session"
)
raise ThpError("Unalloacted session") raise ThpError("Unalloacted session")
self.sessions[session_id].incoming_message.publish( self.sessions[session_id].incoming_message.publish(
@ -296,11 +335,17 @@ class Channel(Context):
expected_type = protobuf.type_for_wire(message_type) expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type) message = message_handler.wrap_protobuf_load(buf, expected_type)
print("channel._handle_channel_message:", message) if __debug__:
log.debug(__name__, "handle_channel_message: %s", message)
# TODO handle other messages than CreateNewSession # TODO handle other messages than CreateNewSession
if TYPE_CHECKING: if TYPE_CHECKING:
assert isinstance(message, ThpCreateNewSession) assert isinstance(message, ThpCreateNewSession)
print("channel._handle_channel_message - passphrase:", message.passphrase) if __debug__:
log.debug(
__name__,
"handle_channel_message - passphrase: %s",
message.passphrase,
)
# await thp_messages.handle_CreateNewSession(message) # await thp_messages.handle_CreateNewSession(message)
if message.passphrase is not None: if message.passphrase is not None:
new_session_id: int = self.create_new_session(message.passphrase) new_session_id: int = self.create_new_session(message.passphrase)
@ -311,7 +356,10 @@ class Channel(Context):
message_size: int = thp_messages.get_new_session_message( message_size: int = thp_messages.get_new_session_message(
bufferrone, new_session_id bufferrone, new_session_id
) )
print(message_size) # TODO adjust if __debug__:
log.debug(
__name__, "handle_channel_message - message size: %d", message_size
)
loop.schedule(self.write_and_encrypt(bufferrone)) loop.schedule(self.write_and_encrypt(bufferrone))
# TODO not finished # TODO not finished
@ -332,7 +380,8 @@ class Channel(Context):
) )
def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None: def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None:
print("channel._encrypt") if __debug__:
log.debug(__name__, "encrypt")
min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
if len(buffer) < min_required_length or not isinstance(buffer, bytearray): if len(buffer) < min_required_length or not isinstance(buffer, bytearray):
new_buffer = bytearray(min_required_length) new_buffer = bytearray(min_required_length)
@ -359,18 +408,16 @@ class Channel(Context):
async def _send_ack(self, ack_bit: int) -> None: async def _send_ack(self, ack_bit: int) -> None:
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = InitHeader( header = InitHeader(ctrl_byte, self.get_channel_id_int(), CHECKSUM_LENGTH)
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
)
chksum = checksum.compute(header.to_bytes()) chksum = checksum.compute(header.to_bytes())
if __debug__: if __debug__:
log.debug( log.debug(
__name__, __name__,
"Writing ACK message to a channel with id: %d, sync bit: %d", "Writing ACK message to a channel with id: %d, sync bit: %d",
int.from_bytes(self.channel_id, "big"), self.get_channel_id_int(),
ack_bit, ack_bit,
) )
await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH) await self._write_payload_to_wire(header, chksum)
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit): def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
if sync_bit == 0: if sync_bit == 0:
@ -382,10 +429,28 @@ class Channel(Context):
# CALLED BY WORKFLOW / SESSION CONTEXT # CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
print("channel.write:" + msg.MESSAGE_NAME) if __debug__:
log.debug(__name__, "channel.write: %s", msg.MESSAGE_NAME)
noise_payload_len = self._encode_into_buffer(msg, session_id) noise_payload_len = self._encode_into_buffer(msg, session_id)
await self.write_and_encrypt(self.buffer[:noise_payload_len]) await self.write_and_encrypt(self.buffer[:noise_payload_len])
async def write_error(self, err_type: FailureType, message: str) -> None:
if __debug__:
log.debug(__name__, "write_error")
msg_size = self._encode_error_into_buffer(err_type, message)
data_length = MESSAGE_TYPE_LENGTH + msg_size
header: InitHeader = InitHeader(
ERROR, self.get_channel_id_int(), data_length + CHECKSUM_LENGTH
)
chksum = checksum.compute(
header.to_bytes() + memoryview(self.buffer[:data_length])
)
utils.memcpy(self.buffer, data_length, chksum, 0)
await self._write_payload_to_wire(
header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH])
)
async def write_and_encrypt(self, payload: bytes) -> None: async def write_and_encrypt(self, payload: bytes) -> None:
payload_length = len(payload) payload_length = len(payload)
@ -394,29 +459,31 @@ class Channel(Context):
self._encrypt(self.buffer, payload_length) self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH payload_length = payload_length + TAG_LENGTH
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_length])) loop.schedule(
self._write_encrypted_payload_loop(memoryview(self.buffer[:payload_length]))
)
async def _write_encrypted_payload_loop(self, payload: bytes) -> None: async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
print("channel._write_encrypted_payload_loop") if __debug__:
log.debug(__name__, "write_encrypted_payload_loop")
payload_len = len(payload) + CHECKSUM_LENGTH payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = THP.sync_get_send_bit(self.channel_cache) sync_bit = THP.sync_get_send_bit(self.channel_cache)
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit) ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
header = InitHeader( header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
)
chksum = checksum.compute(header.to_bytes() + payload) chksum = checksum.compute(header.to_bytes() + payload)
payload = payload + chksum payload = payload + chksum
# TODO add condition that disallows to write when can_send_message is false # TODO add condition that disallows to write when can_send_message is false
THP.sync_set_can_send_message(self.channel_cache, False) THP.sync_set_can_send_message(self.channel_cache, False)
while True: while True:
print( if __debug__:
"channel._write_encrypted_payload_loop - loop start, sync_bit:", log.debug(
(header.ctrl_byte & 0x10) >> 4, __name__,
" sync_send_bit:", "write_encrypted_payload_loop - loop start, sync_bit: %d, sync_send_bit: %d",
THP.sync_get_send_bit(self.channel_cache), (header.ctrl_byte & 0x10) >> 4,
) THP.sync_get_send_bit(self.channel_cache),
await self._write_payload_to_wire(header, payload, payload_len) )
await self._write_payload_to_wire(header, payload)
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack()) self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
try: try:
await self.waiting_for_ack_timeout await self.waiting_for_ack_timeout
@ -424,29 +491,36 @@ class Channel(Context):
THP.sync_set_send_bit_to_opposite(self.channel_cache) THP.sync_set_send_bit_to_opposite(self.channel_cache)
break break
async def _write_payload_to_wire( async def _write_payload_to_wire(self, header: InitHeader, payload: bytes):
self, header: InitHeader, payload: bytes, payload_len: int if __debug__:
): log.debug(__name__, "write_payload_to_wire")
print("chanel._write_payload_to_wire")
# prepare the report buffer with header data # prepare the report buffer with header data
payload_len = len(payload)
report = bytearray(REPORT_LENGTH) report = bytearray(REPORT_LENGTH)
header.pack_to_buffer(report) header.pack_to_buffer(report)
# write initial report # write initial report
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0) nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
await self._write_report_to_wire(report) await self._write_report_to_wire(report)
# if we have more data to write, use continuation reports for it # if we have more data to write, use continuation reports for it
if nwritten < payload_len: if nwritten < payload_len:
header.pack_to_cont_buffer(report) header.pack_to_cont_buffer(report)
while nwritten < payload_len: while nwritten < payload_len:
if nwritten >= payload_len - REPORT_LENGTH:
report = bytearray(REPORT_LENGTH)
header.pack_to_cont_buffer(report)
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten) nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
await self._write_report_to_wire(report) await self._write_report_to_wire(report)
async def _write_report_to_wire(self, report: utils.BufferType) -> None: async def _write_report_to_wire(self, report: utils.BufferType) -> None:
while True: while True:
await loop.wait(self.iface.iface_num() | io.POLL_WRITE) await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
printBytes(report) # TODO remove if __debug__:
log.debug(
__name__, "write_report_to_wire: %s", get_bytes_as_str(report)
)
n = self.iface.write(report) n = self.iface.write(report)
if n == len(report): if n == len(report):
return return
@ -460,41 +534,52 @@ class Channel(Context):
assert msg.MESSAGE_WIRE_TYPE is not None assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg) msg_size = protobuf.encoded_length(msg)
offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
payload_size = offset + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(self.buffer) or not isinstance( if required_min_size > len(self.buffer):
self.buffer, bytearray # message is too big, we need to allocate a new buffer
):
# message is too big or buffer is not bytearray, we need to allocate a new buffer
self.buffer = bytearray(required_min_size) self.buffer = bytearray(required_min_size)
buffer = self.buffer buffer = self.buffer
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
msg_type_bytes = int.to_bytes(msg.MESSAGE_WIRE_TYPE, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, 0, session_id_bytes, 0) _encode_session_into_buffer(memoryview(buffer), session_id)
utils.memcpy(buffer, SESSION_ID_LENGTH, msg_type_bytes, 0) _encode_message_type_into_buffer(
assert isinstance(buffer, bytearray) memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
msg_size = protobuf.encode(buffer[offset:], msg) )
_encode_message_into_buffer(
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
return payload_size return payload_size
def _encode_error_into_buffer(self, err_code: FailureType, message: str) -> int:
error_message: protobuf.MessageType = Failure(code=err_code, message=message)
_encode_message_type_into_buffer(memoryview(self.buffer), MessageType.Failure)
_encode_message_into_buffer(
memoryview(self.buffer), error_message, MESSAGE_TYPE_LENGTH
)
return protobuf.encoded_length(error_message)
def create_new_session( def create_new_session(
self, self,
passphrase="", passphrase="",
) -> int: ) -> int:
print("channel.create_new_session") if __debug__:
log.debug(__name__, " create_new_session")
from trezor.wire.thp.session_context import SessionContext from trezor.wire.thp.session_context import SessionContext
session = SessionContext.create_new_session(self) session = SessionContext.create_new_session(self)
self.sessions[session.session_id] = session self.sessions[session.session_id] = session
loop.schedule(session.handle()) loop.schedule(session.handle())
print( if __debug__:
"channel.create_new_session - new session created. Session id:", log.debug(
session.session_id, __name__,
) "create_new_session - new session created. Session id: %d",
print(self.sessions) session.session_id,
)
if __debug__:
print(self.sessions)
return session.session_id return session.session_id
def _todo_clear_buffer(self): def _todo_clear_buffer(self):
@ -504,15 +589,21 @@ class Channel(Context):
# TODO add debug logging to ACK handling # TODO add debug logging to ACK handling
def _handle_received_ACK(self, sync_bit: int) -> None: def _handle_received_ACK(self, sync_bit: int) -> None:
if self._ack_is_not_expected(): if self._ack_is_not_expected():
print("channel._handle_received_ACK - ack is not expected") if __debug__:
log.debug(__name__, "handle_received_ACK - ack is not expected")
return return
if self._ack_has_incorrect_sync_bit(sync_bit): if self._ack_has_incorrect_sync_bit(sync_bit):
print("channel._handle_received_ACK - ack has incorrect sync bit") if __debug__:
log.debug(
__name__,
"handle_received_ACK - ack has incorrect sync bit",
)
return return
if self.waiting_for_ack_timeout is not None: if self.waiting_for_ack_timeout is not None:
self.waiting_for_ack_timeout.close() self.waiting_for_ack_timeout.close()
print("channel._handle_received_ACK - closed waiting for ack") if __debug__:
log.debug(__name__, "handle_received_ACK - closed waiting for ack")
THP.sync_set_can_send_message(self.channel_cache, True) THP.sync_set_can_send_message(self.channel_cache, True)
@ -558,10 +649,13 @@ def _get_buffer_for_message(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType: ) -> utils.BufferType:
length = payload_length + INIT_DATA_OFFSET length = payload_length + INIT_DATA_OFFSET
print("channel._get_buffer_for_message - length", length) if __debug__:
print( log.debug(__name__, "get_buffer_for_message - length: %d", length)
"channel._get_buffer_for_message - existing buffer type", type(existing_buffer) log.debug(
) __name__,
"get_buffer_for_message - existing buffer type: %s",
type(existing_buffer),
)
if length > max_length: if length > max_length:
raise ThpError("Message too large") raise ThpError("Message too large")
@ -606,6 +700,26 @@ def is_channel_state_pairing(state: int) -> bool:
return False return False
def _encode_session_into_buffer(
buffer: memoryview, session_id: int, buffer_offset: int = 0
) -> None:
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
def _encode_message_type_into_buffer(
buffer: memoryview, message_type: int, offset: int = 0
) -> None:
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, offset, msg_type_bytes, 0)
def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _state_to_str(state: int) -> str: def _state_to_str(state: int) -> str:
name = { name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__") v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
@ -615,5 +729,22 @@ def _state_to_str(state: int) -> str:
return "UNKNOWN_STATE" return "UNKNOWN_STATE"
def printBytes(a): def get_bytes_as_str(a):
print(hexlify(a).decode("utf-8")) return hexlify(a).decode("utf-8")
def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, ThpError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
# NOTE: when receiving generic `FirmwareError` on non-debug build,
# change the `if __debug__` to `if True` to get the full error message.
if __debug__:
message = str(exc)
else:
message = "Firmware error"
return Failure(code=FailureType.FirmwareError, message=message)

@ -12,7 +12,7 @@ CONTINUATION_PACKET = 0x80
ENCRYPTED_TRANSPORT = 0x02 ENCRYPTED_TRANSPORT = 0x02
HANDSHAKE_INIT = 0x00 HANDSHAKE_INIT = 0x00
ACK_MESSAGE = 0x20 ACK_MESSAGE = 0x20
_ERROR = 0x42 ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40 CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41 _CHANNEL_ALLOCATION_RES = 0x41
@ -23,7 +23,7 @@ TREZOR_STATE_PAIRED = b"\x01"
class InitHeader: class InitHeader:
format_str = ">BHH" format_str = ">BHH"
def __init__(self, ctrl_byte, cid: int, length: int) -> None: def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte self.ctrl_byte = ctrl_byte
self.cid = cid self.cid = cid
self.length = length self.length = length
@ -48,7 +48,7 @@ class InitHeader:
@classmethod @classmethod
def get_error_header(cls, cid, length): def get_error_header(cls, cid, length):
return cls(_ERROR, cid, length) return cls(ERROR, cid, length)
@classmethod @classmethod
def get_channel_allocation_response_header(cls, length): def get_channel_allocation_response_header(cls, length):

@ -45,7 +45,8 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
while True: while True:
try: try:
print("thp_v1.thp_main_loop") if __debug__:
log.debug(__name__, "thp_main_loop")
packet = await read packet = await read
ctrl_byte, cid = ustruct.unpack(">BH", packet) ctrl_byte, cid = ustruct.unpack(">BH", packet)

Loading…
Cancel
Save