diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index ff9546314f..34adcea7a5 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -268,6 +268,7 @@ async def _handle_qr_code_tag( ) -> protobuf.MessageType: if TYPE_CHECKING: assert isinstance(message, ThpQrCodeTag) + assert ctx.display_data.code_qr_code is not None expected_tag = sha256(ctx.display_data.code_qr_code).digest() if expected_tag != message.tag: print( diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index f6d6b6ae62..1676505f1e 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -37,6 +37,7 @@ if __debug__: if TYPE_CHECKING: from trezorio import WireInterface + from typing import Awaitable from .pairing_context import PairingContext from .session_context import GenericSessionContext @@ -113,7 +114,7 @@ class Channel: # CALLED BY THP_MAIN_LOOP - async def receive_packet(self, packet: utils.BufferType): + def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: if __debug__: log.debug( __name__, @@ -121,7 +122,7 @@ class Channel: utils.get_bytes_as_str(self.channel_id), ) - await self._handle_received_packet(packet) + self._handle_received_packet(packet) if __debug__: log.debug( @@ -133,7 +134,7 @@ class Channel: if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: self._finish_message() - await received_message_handler.handle_received_message(self, self.buffer) + return received_message_handler.handle_received_message(self, self.buffer) elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read: self.is_cont_packet_expected = True else: @@ -141,14 +142,13 @@ class Channel: "Read more bytes than is the expected length of the message!" ) - async def _handle_received_packet(self, packet: utils.BufferType) -> None: + def _handle_received_packet(self, packet: utils.BufferType) -> None: ctrl_byte = packet[0] if control_byte.is_continuation(ctrl_byte): - await self._handle_cont_packet(packet) - else: - await self._handle_init_packet(packet) + return self._handle_cont_packet(packet) + return self._handle_init_packet(packet) - async def _handle_init_packet(self, packet: utils.BufferType) -> None: + def _handle_init_packet(self, packet: utils.BufferType) -> None: if __debug__: log.debug( __name__, @@ -175,7 +175,6 @@ class Channel: packet_payload, payload_length, ) - await self._buffer_packet_data(self.buffer, packet, 0) if __debug__: log.debug( @@ -190,8 +189,9 @@ class Channel: utils.get_bytes_as_str(self.channel_id), len(self.buffer), ) + return self._buffer_packet_data(self.buffer, packet, 0) - async def _handle_cont_packet(self, packet: utils.BufferType) -> None: + def _handle_cont_packet(self, packet: utils.BufferType) -> None: if __debug__: log.debug( __name__, @@ -200,7 +200,7 @@ class Channel: ) if not self.is_cont_packet_expected: raise ThpError("Continuation packet is not expected, ignoring") - await self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH) + return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH) def _decrypt_single_packet_payload( self, payload: utils.BufferType @@ -297,7 +297,7 @@ class Channel: buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag - async def _buffer_packet_data( + def _buffer_packet_data( self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int ): self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) @@ -309,7 +309,7 @@ class Channel: # CALLED BY WORKFLOW / SESSION CONTEXT - async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: + def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: if __debug__ and utils.EMULATOR: log.debug( __name__, @@ -323,15 +323,15 @@ class Channel: noise_payload_len = memory_manager.encode_into_buffer( self.buffer, msg, session_id ) - await self.write_and_encrypt(self.buffer[:noise_payload_len]) + return self.write_and_encrypt(self.buffer[:noise_payload_len]) - async def write_error(self, err_type: int): + def write_error(self, err_type: int) -> Awaitable[None]: msg_data = err_type.to_bytes(1, "big") length = len(msg_data) + CHECKSUM_LENGTH header = PacketHeader.get_error_header(self.get_channel_id_int(), length) - await write_payload_to_wire_and_add_checksum(self.iface, header, msg_data) + return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data) - async def write_and_encrypt(self, payload: bytes) -> None: + def write_and_encrypt(self, payload: bytes) -> None: payload_length = len(payload) self._encrypt(self.buffer, payload_length) payload_length = payload_length + TAG_LENGTH @@ -346,7 +346,7 @@ class Channel: ) ) - async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: + def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: self._prepare_write() self.write_task_spawn = loop.spawn( self._write_encrypted_payload_loop(ctrl_byte, payload) diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index bfced1461a..831f82ca48 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -166,7 +166,7 @@ class PairingContext(Context): return message_handler.wrap_protobuf_load(message.data, expected_type) async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel_ctx.write(msg) + return self.channel_ctx.write(msg) async def call( self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType] diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py index 03df43312a..a8295dca40 100644 --- a/core/src/trezor/wire/thp/received_message_handler.py +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -47,6 +47,8 @@ from .writer import ( ) if TYPE_CHECKING: + from typing import Awaitable + from trezor.messages import ThpHandshakeCompletionReqNoisePayload from .channel import Channel @@ -68,7 +70,9 @@ async def handle_received_message( import micropython micropython.mem_info() - print("Allocation count:", micropython.alloc_count()) + print( + "Allocation count:", micropython.alloc_count() # type: ignore ["alloc_count" is not a known attribute of module "micropython"] + ) except AttributeError: print("To show allocation count, create the build with TREZOR_MEMPERF=1") ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer) @@ -117,7 +121,7 @@ async def handle_received_message( ) except ThpUnallocatedSessionError as e: error_message = Failure(code=FailureType.ThpUnallocatedSession) - await ctx.write(error_message, e.session_id) + ctx.write(error_message, e.session_id) except ThpDecryptionError: await ctx.write_error(ThpErrorType.DECRYPTION_FAILED) ctx.clear() @@ -128,7 +132,7 @@ async def handle_received_message( log.debug(__name__, "handle_received_message - end") -async def _send_ack(ctx: Channel, ack_bit: int) -> None: +def _send_ack(ctx: Channel, ack_bit: int) -> Awaitable[None]: ctrl_byte = control_byte.add_ack_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) header = PacketHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH) if __debug__: @@ -138,7 +142,7 @@ async def _send_ack(ctx: Channel, ack_bit: int) -> None: ctx.get_channel_id_int(), ack_bit, ) - await write_payload_to_wire_and_add_checksum(ctx.iface, header, b"") + return write_payload_to_wire_and_add_checksum(ctx.iface, header, b"") def _check_checksum(message_length: int, message_buffer: utils.BufferType): @@ -175,31 +179,27 @@ async def _handle_ack(ctx: Channel, ack_bit: int): # this await might not be executed -async def _handle_message_to_app_or_channel( +def _handle_message_to_app_or_channel( ctx: Channel, payload_length: int, message_length: int, ctrl_byte: int, -) -> None: +) -> Awaitable[None]: state = ctx.get_channel_state() if __debug__: log.debug(__name__, "state: %s", state_to_str(state)) if state is ChannelState.ENCRYPTED_TRANSPORT: - await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length) - return + return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length) if state is ChannelState.TH1: - await _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte) - return + return _handle_state_TH1(ctx, payload_length, message_length, ctrl_byte) if state is ChannelState.TH2: - await _handle_state_TH2(ctx, message_length, ctrl_byte) - return + return _handle_state_TH2(ctx, message_length, ctrl_byte) if is_channel_state_pairing(state): - await _handle_pairing(ctx, message_length) - return + return _handle_pairing(ctx, message_length) raise ThpError("Unimplemented channel state") @@ -244,7 +244,7 @@ async def _handle_state_TH1( payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag # send handshake init response message - await ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload) + ctx.write_handshake_message(HANDSHAKE_INIT_RES, payload) ctx.set_channel_state(ChannelState.TH2) return @@ -323,7 +323,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) - if paired: trezor_state = thp_messages.TREZOR_STATE_PAIRED # send hanshake completion response - await ctx.write_handshake_message( + ctx.write_handshake_message( HANDSHAKE_COMP_RES, ctx.handshake.get_handshake_completion_response(trezor_state), ) diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index fcf1cc1eef..f7a6de6304 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -148,7 +148,7 @@ class GenericSessionContext(Context): return message_handler.wrap_protobuf_load(message.data, expected_type) async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel.write(msg, self.session_id) + return self.channel.write(msg, self.session_id) def get_session_state(self) -> SessionState: ... diff --git a/core/src/trezor/wire/thp/writer.py b/core/src/trezor/wire/thp/writer.py index 30ff968575..0cd32cbc8a 100644 --- a/core/src/trezor/wire/thp/writer.py +++ b/core/src/trezor/wire/thp/writer.py @@ -14,18 +14,18 @@ MESSAGE_TYPE_LENGTH = const(2) if TYPE_CHECKING: from trezorio import WireInterface - from typing import Sequence + from typing import Awaitable, Sequence -async def write_payload_to_wire_and_add_checksum( +def write_payload_to_wire_and_add_checksum( iface: WireInterface, header: PacketHeader, transport_payload: bytes -): +) -> Awaitable[None]: header_checksum: int = crc.crc32(header.to_bytes()) checksum: bytes = crc.crc32(transport_payload, header_checksum).to_bytes( CHECKSUM_LENGTH, "big" ) data = (transport_payload, checksum) - await write_payloads_to_wire(iface, header, data) + return write_payloads_to_wire(iface, header, data) async def write_payloads_to_wire( @@ -67,7 +67,16 @@ async def write_payloads_to_wire( raise Exception("Should not happen!!!") packet_number += 1 packet_offset = CONT_HEADER_LENGTH - await write_packet_to_wire(iface, packet) + + # write packet to wire (in-lined) + if __debug__: + log.debug( + __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet) + ) + written_by_iface: int = 0 + while written_by_iface < len(packet): + await loop.wait(iface.iface_num() | io.POLL_WRITE) + written_by_iface = iface.write(packet) async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None: @@ -77,6 +86,6 @@ async def write_packet_to_wire(iface: WireInterface, packet: bytes) -> None: log.debug( __name__, "write_packet_to_wire: %s", utils.get_bytes_as_str(packet) ) - n = iface.write(packet) - if n == len(packet): + n_written = iface.write(packet) + if n_written == len(packet): return diff --git a/core/src/trezor/wire/thp_main.py b/core/src/trezor/wire/thp_main.py index 7dff25112f..67c8770716 100644 --- a/core/src/trezor/wire/thp_main.py +++ b/core/src/trezor/wire/thp_main.py @@ -143,7 +143,9 @@ async def _handle_allocated( raise ThpError("Channel has different WireInterface") if channel.get_channel_state() != ChannelState.UNALLOCATED: - await channel.receive_packet(packet) + x = channel.receive_packet(packet) + if x is not None: + await x async def _handle_unallocated(iface, cid) -> None: