diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 2df9cd222..23c9f66e8 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -7,8 +7,10 @@ import usb from storage import cache_thp from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache from trezor import io, log, loop, protobuf, utils -from trezor.messages import ThpCreateNewSession +from trezor.enums import FailureType, MessageType +from trezor.messages import Failure, ThpCreateNewSession from trezor.wire import message_handler +from trezor.wire.errors import Error from trezor.wire.thp import thp_messages from ..protocol_common import Context, MessageWithType @@ -19,6 +21,7 @@ from .thp_messages import ( ACK_MESSAGE, CONTINUATION_PACKET, ENCRYPTED_TRANSPORT, + ERROR, HANDSHAKE_INIT, InitHeader, ) @@ -72,27 +75,35 @@ class Channel(Context): def get_channel_state(self) -> int: state = int.from_bytes(self.channel_cache.state, "big") - print("channel.get_ch_state:", _state_to_str(state)) + if __debug__: + log.debug(__name__, "get_channel_state: %s", _state_to_str(state)) return state + def get_channel_id_int(self) -> int: + return int.from_bytes(self.channel_id, "big") + def set_channel_state(self, state: ChannelState) -> None: - print("channel.set_ch_state:", _state_to_str(state)) + if __debug__: + log.debug(__name__, "set_channel_state: %s", _state_to_str(state)) self.channel_cache.state = bytearray(state.to_bytes(1, "big")) def set_buffer(self, buffer: utils.BufferType) -> None: self.buffer = buffer - print("channel.set_buffer:", type(self.buffer)) + if __debug__: + log.debug(__name__, "set_buffer: %s", type(self.buffer)) # CALLED BY THP_MAIN_LOOP async def receive_packet(self, packet: utils.BufferType): - print("channel.receive_packet") + if __debug__: + log.debug(__name__, "receive_packet") ctrl_byte = packet[0] if _is_ctrl_byte_continuation(ctrl_byte): await self._handle_cont_packet(packet) else: await self._handle_init_packet(packet) - printBytes(self.buffer) + if __debug__: + log.debug(__name__, "self.buffer: %s", get_bytes_as_str(self.buffer)) if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: self._finish_message() await self._handle_completed_message() @@ -104,7 +115,8 @@ class Channel(Context): ) async def _handle_init_packet(self, packet: utils.BufferType): - print("channel._handle_init_packet") + if __debug__: + log.debug(__name__, "handle_init_packet") ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) self.expected_payload_length = payload_length packet_payload = packet[5:] @@ -133,20 +145,27 @@ class Channel(Context): payload_length, self.buffer ) except Exception as e: - print(e) - print("channel._handle_init_packet - payload len", payload_length) - print("channel._handle_init_packet - buffer len", len(self.buffer)) + if __debug__: + log.exception(__name__, e) + if __debug__: + log.debug(__name__, "handle_init_packet - payload len: %d", payload_length) + if __debug__: + log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer)) + await self._buffer_packet_data(self.buffer, packet, 0) - print("channel._handle_init_packet - end") + if __debug__: + log.debug(__name__, "channel._handle_init_packet - end") async def _handle_cont_packet(self, packet: utils.BufferType): - print("channel._handle_cont_packet") + if __debug__: + log.debug(__name__, "handle_cont_packet") if not self.is_cont_packet_expected: raise ThpError("Continuation packet is not expected, ignoring") await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET) async def _handle_completed_message(self) -> None: - print("channel._handle_completed_message") + if __debug__: + log.debug(__name__, "handle_completed_message") ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) message_length = payload_length + INIT_DATA_OFFSET @@ -154,7 +173,12 @@ class Channel(Context): # Synchronization process sync_bit = (ctrl_byte & 0x10) >> 4 - print("channel._handle_completed_message - sync bit of message:", sync_bit) + if __debug__: + log.debug( + __name__, + "handle_completed_message - sync bit of message: %d", + sync_bit, + ) # 1: Handle ACKs if _is_ctrl_byte_ack(ctrl_byte): @@ -179,15 +203,19 @@ class Channel(Context): await self._handle_message_to_app_or_channel( payload_length, message_length, ctrl_byte, sync_bit ) - print("channel._handle_completed_message - end") + if __debug__: + log.debug(__name__, "handle_completed_message - end") def _check_checksum(self, message_length: int): - print("channel._check_checksum") + if __debug__: + log.debug(__name__, "check_checksum") if not checksum.is_valid( checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length], data=self.buffer[: message_length - CHECKSUM_LENGTH], ): self._todo_clear_buffer() + if __debug__: + log.debug(__name__, "Invalid checksum, ignoring message.") raise ThpError("Invalid checksum, ignoring message.") async def _handle_message_to_app_or_channel( @@ -195,7 +223,7 @@ class Channel(Context): ) -> None: state = self.get_channel_state() if __debug__: - log.debug(__name__, "state: " + _state_to_str(state)) + log.debug(__name__, "state: %s", _state_to_str(state)) if state is ChannelState.TH1: await self._handle_state_TH1(payload_length, message_length, sync_bit) @@ -206,7 +234,7 @@ class Channel(Context): raise ThpError("Message is not encrypted. Ignoring") if state is ChannelState.ENCRYPTED_TRANSPORT: - self._handle_state_ENCRYPTED_TRANSPORT(message_length) + await self._handle_state_ENCRYPTED_TRANSPORT(message_length) return if state is ChannelState.TH2: @@ -239,7 +267,8 @@ class Channel(Context): return async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None: - print("channel._handle_state_TH2") + if __debug__: + log.debug(__name__, "handle_state_TH2") host_encrypted_static_pubkey = self.buffer[ INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH ] @@ -249,10 +278,13 @@ class Channel(Context): + TAG_LENGTH : message_length - CHECKSUM_LENGTH ] - print( - host_encrypted_static_pubkey, - handshake_completion_request_noise_payload, - ) # TODO remove + if __debug__: + log.debug( + __name__, + "host static pubkey: %s, noise payload: %s", + get_bytes_as_str(host_encrypted_static_pubkey), + get_bytes_as_str(handshake_completion_request_noise_payload), + ) # send hanshake completion response loop.schedule( @@ -262,8 +294,9 @@ class Channel(Context): ) self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) - def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None: - print("channel._handle_state_ENCRYPTED_TRANSPORT") + async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None: + if __debug__: + log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT") self._decrypt_buffer(message_length) session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:]) if session_id == 0: @@ -271,10 +304,16 @@ class Channel(Context): return if session_id not in self.sessions: + await self.write_error( + FailureType.ThpUnallocatedSession, "Unallocated session" + ) raise ThpError("Unalloacted session") session_state = self.sessions[session_id].get_session_state() if session_state is SessionState.UNALLOCATED: + await self.write_error( + FailureType.ThpUnallocatedSession, "Unallocated session" + ) raise ThpError("Unalloacted session") self.sessions[session_id].incoming_message.publish( @@ -296,11 +335,17 @@ class Channel(Context): expected_type = protobuf.type_for_wire(message_type) message = message_handler.wrap_protobuf_load(buf, expected_type) - print("channel._handle_channel_message:", message) + if __debug__: + log.debug(__name__, "handle_channel_message: %s", message) # TODO handle other messages than CreateNewSession if TYPE_CHECKING: assert isinstance(message, ThpCreateNewSession) - print("channel._handle_channel_message - passphrase:", message.passphrase) + if __debug__: + log.debug( + __name__, + "handle_channel_message - passphrase: %s", + message.passphrase, + ) # await thp_messages.handle_CreateNewSession(message) if message.passphrase is not None: new_session_id: int = self.create_new_session(message.passphrase) @@ -311,7 +356,10 @@ class Channel(Context): message_size: int = thp_messages.get_new_session_message( bufferrone, new_session_id ) - print(message_size) # TODO adjust + if __debug__: + log.debug( + __name__, "handle_channel_message - message size: %d", message_size + ) loop.schedule(self.write_and_encrypt(bufferrone)) # TODO not finished @@ -332,7 +380,8 @@ class Channel(Context): ) def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None: - print("channel._encrypt") + if __debug__: + log.debug(__name__, "encrypt") min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH if len(buffer) < min_required_length or not isinstance(buffer, bytearray): new_buffer = bytearray(min_required_length) @@ -359,18 +408,16 @@ class Channel(Context): async def _send_ack(self, ack_bit: int) -> None: ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) - header = InitHeader( - ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH - ) + header = InitHeader(ctrl_byte, self.get_channel_id_int(), CHECKSUM_LENGTH) chksum = checksum.compute(header.to_bytes()) if __debug__: log.debug( __name__, "Writing ACK message to a channel with id: %d, sync bit: %d", - int.from_bytes(self.channel_id, "big"), + self.get_channel_id_int(), ack_bit, ) - await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH) + await self._write_payload_to_wire(header, chksum) def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit): if sync_bit == 0: @@ -382,10 +429,28 @@ class Channel(Context): # CALLED BY WORKFLOW / SESSION CONTEXT async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: - print("channel.write:" + msg.MESSAGE_NAME) + if __debug__: + log.debug(__name__, "channel.write: %s", msg.MESSAGE_NAME) noise_payload_len = self._encode_into_buffer(msg, session_id) await self.write_and_encrypt(self.buffer[:noise_payload_len]) + async def write_error(self, err_type: FailureType, message: str) -> None: + if __debug__: + log.debug(__name__, "write_error") + msg_size = self._encode_error_into_buffer(err_type, message) + data_length = MESSAGE_TYPE_LENGTH + msg_size + header: InitHeader = InitHeader( + ERROR, self.get_channel_id_int(), data_length + CHECKSUM_LENGTH + ) + chksum = checksum.compute( + header.to_bytes() + memoryview(self.buffer[:data_length]) + ) + + utils.memcpy(self.buffer, data_length, chksum, 0) + await self._write_payload_to_wire( + header, memoryview(self.buffer[: data_length + CHECKSUM_LENGTH]) + ) + async def write_and_encrypt(self, payload: bytes) -> None: payload_length = len(payload) @@ -394,29 +459,31 @@ class Channel(Context): self._encrypt(self.buffer, payload_length) payload_length = payload_length + TAG_LENGTH - loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_length])) + loop.schedule( + self._write_encrypted_payload_loop(memoryview(self.buffer[:payload_length])) + ) async def _write_encrypted_payload_loop(self, payload: bytes) -> None: - print("channel._write_encrypted_payload_loop") + if __debug__: + log.debug(__name__, "write_encrypted_payload_loop") payload_len = len(payload) + CHECKSUM_LENGTH sync_bit = THP.sync_get_send_bit(self.channel_cache) ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit) - header = InitHeader( - ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len - ) + header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len) chksum = checksum.compute(header.to_bytes() + payload) payload = payload + chksum # TODO add condition that disallows to write when can_send_message is false THP.sync_set_can_send_message(self.channel_cache, False) while True: - print( - "channel._write_encrypted_payload_loop - loop start, sync_bit:", - (header.ctrl_byte & 0x10) >> 4, - " sync_send_bit:", - THP.sync_get_send_bit(self.channel_cache), - ) - await self._write_payload_to_wire(header, payload, payload_len) + if __debug__: + log.debug( + __name__, + "write_encrypted_payload_loop - loop start, sync_bit: %d, sync_send_bit: %d", + (header.ctrl_byte & 0x10) >> 4, + THP.sync_get_send_bit(self.channel_cache), + ) + await self._write_payload_to_wire(header, payload) self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack()) try: await self.waiting_for_ack_timeout @@ -424,29 +491,36 @@ class Channel(Context): THP.sync_set_send_bit_to_opposite(self.channel_cache) break - async def _write_payload_to_wire( - self, header: InitHeader, payload: bytes, payload_len: int - ): - print("chanel._write_payload_to_wire") + async def _write_payload_to_wire(self, header: InitHeader, payload: bytes): + if __debug__: + log.debug(__name__, "write_payload_to_wire") # prepare the report buffer with header data + payload_len = len(payload) report = bytearray(REPORT_LENGTH) header.pack_to_buffer(report) # write initial report nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0) + await self._write_report_to_wire(report) # if we have more data to write, use continuation reports for it if nwritten < payload_len: header.pack_to_cont_buffer(report) while nwritten < payload_len: + if nwritten >= payload_len - REPORT_LENGTH: + report = bytearray(REPORT_LENGTH) + header.pack_to_cont_buffer(report) nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten) await self._write_report_to_wire(report) async def _write_report_to_wire(self, report: utils.BufferType) -> None: while True: await loop.wait(self.iface.iface_num() | io.POLL_WRITE) - printBytes(report) # TODO remove + if __debug__: + log.debug( + __name__, "write_report_to_wire: %s", get_bytes_as_str(report) + ) n = self.iface.write(report) if n == len(report): return @@ -460,41 +534,52 @@ class Channel(Context): assert msg.MESSAGE_WIRE_TYPE is not None msg_size = protobuf.encoded_length(msg) - offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH - payload_size = offset + msg_size + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH - if required_min_size > len(self.buffer) or not isinstance( - self.buffer, bytearray - ): - # message is too big or buffer is not bytearray, we need to allocate a new buffer + if required_min_size > len(self.buffer): + # message is too big, we need to allocate a new buffer self.buffer = bytearray(required_min_size) buffer = self.buffer - session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big") - msg_type_bytes = int.to_bytes(msg.MESSAGE_WIRE_TYPE, MESSAGE_TYPE_LENGTH, "big") - utils.memcpy(buffer, 0, session_id_bytes, 0) - utils.memcpy(buffer, SESSION_ID_LENGTH, msg_type_bytes, 0) - assert isinstance(buffer, bytearray) - msg_size = protobuf.encode(buffer[offset:], msg) + _encode_session_into_buffer(memoryview(buffer), session_id) + _encode_message_type_into_buffer( + memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH + ) + _encode_message_into_buffer( + memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + ) + return payload_size + def _encode_error_into_buffer(self, err_code: FailureType, message: str) -> int: + error_message: protobuf.MessageType = Failure(code=err_code, message=message) + _encode_message_type_into_buffer(memoryview(self.buffer), MessageType.Failure) + _encode_message_into_buffer( + memoryview(self.buffer), error_message, MESSAGE_TYPE_LENGTH + ) + return protobuf.encoded_length(error_message) + def create_new_session( self, passphrase="", ) -> int: - print("channel.create_new_session") + if __debug__: + log.debug(__name__, " create_new_session") from trezor.wire.thp.session_context import SessionContext session = SessionContext.create_new_session(self) self.sessions[session.session_id] = session loop.schedule(session.handle()) - print( - "channel.create_new_session - new session created. Session id:", - session.session_id, - ) - print(self.sessions) + if __debug__: + log.debug( + __name__, + "create_new_session - new session created. Session id: %d", + session.session_id, + ) + if __debug__: + print(self.sessions) return session.session_id def _todo_clear_buffer(self): @@ -504,15 +589,21 @@ class Channel(Context): # TODO add debug logging to ACK handling def _handle_received_ACK(self, sync_bit: int) -> None: if self._ack_is_not_expected(): - print("channel._handle_received_ACK - ack is not expected") + if __debug__: + log.debug(__name__, "handle_received_ACK - ack is not expected") return if self._ack_has_incorrect_sync_bit(sync_bit): - print("channel._handle_received_ACK - ack has incorrect sync bit") + if __debug__: + log.debug( + __name__, + "handle_received_ACK - ack has incorrect sync bit", + ) return if self.waiting_for_ack_timeout is not None: self.waiting_for_ack_timeout.close() - print("channel._handle_received_ACK - closed waiting for ack") + if __debug__: + log.debug(__name__, "handle_received_ACK - closed waiting for ack") THP.sync_set_can_send_message(self.channel_cache, True) @@ -558,10 +649,13 @@ def _get_buffer_for_message( payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN ) -> utils.BufferType: length = payload_length + INIT_DATA_OFFSET - print("channel._get_buffer_for_message - length", length) - print( - "channel._get_buffer_for_message - existing buffer type", type(existing_buffer) - ) + if __debug__: + log.debug(__name__, "get_buffer_for_message - length: %d", length) + log.debug( + __name__, + "get_buffer_for_message - existing buffer type: %s", + type(existing_buffer), + ) if length > max_length: raise ThpError("Message too large") @@ -606,6 +700,26 @@ def is_channel_state_pairing(state: int) -> bool: return False +def _encode_session_into_buffer( + buffer: memoryview, session_id: int, buffer_offset: int = 0 +) -> None: + session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big") + utils.memcpy(buffer, buffer_offset, session_id_bytes, 0) + + +def _encode_message_type_into_buffer( + buffer: memoryview, message_type: int, offset: int = 0 +) -> None: + msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big") + utils.memcpy(buffer, offset, msg_type_bytes, 0) + + +def _encode_message_into_buffer( + buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0 +) -> None: + protobuf.encode(memoryview(buffer[buffer_offset:]), message) + + def _state_to_str(state: int) -> str: name = { v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__") @@ -615,5 +729,22 @@ def _state_to_str(state: int) -> str: return "UNKNOWN_STATE" -def printBytes(a): - print(hexlify(a).decode("utf-8")) +def get_bytes_as_str(a): + return hexlify(a).decode("utf-8") + + +def failure(exc: BaseException) -> Failure: + if isinstance(exc, Error): + return Failure(code=exc.code, message=exc.message) + elif isinstance(exc, loop.TaskClosed): + return Failure(code=FailureType.ActionCancelled, message="Cancelled") + elif isinstance(exc, ThpError): + return Failure(code=FailureType.InvalidSession, message="Invalid session") + else: + # NOTE: when receiving generic `FirmwareError` on non-debug build, + # change the `if __debug__` to `if True` to get the full error message. + if __debug__: + message = str(exc) + else: + message = "Firmware error" + return Failure(code=FailureType.FirmwareError, message=message) diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index 585811870..d08e3e446 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -12,7 +12,7 @@ CONTINUATION_PACKET = 0x80 ENCRYPTED_TRANSPORT = 0x02 HANDSHAKE_INIT = 0x00 ACK_MESSAGE = 0x20 -_ERROR = 0x42 +ERROR = 0x42 CHANNEL_ALLOCATION_REQ = 0x40 _CHANNEL_ALLOCATION_RES = 0x41 @@ -23,7 +23,7 @@ TREZOR_STATE_PAIRED = b"\x01" class InitHeader: format_str = ">BHH" - def __init__(self, ctrl_byte, cid: int, length: int) -> None: + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: self.ctrl_byte = ctrl_byte self.cid = cid self.length = length @@ -48,7 +48,7 @@ class InitHeader: @classmethod def get_error_header(cls, cid, length): - return cls(_ERROR, cid, length) + return cls(ERROR, cid, length) @classmethod def get_channel_allocation_response_header(cls, length): diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index e77ec434b..4c9562e46 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -45,7 +45,8 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): while True: try: - print("thp_v1.thp_main_loop") + if __debug__: + log.debug(__name__, "thp_main_loop") packet = await read ctrl_byte, cid = ustruct.unpack(">BH", packet)