diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 8ca0c0694..6122e7b87 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -41,8 +41,6 @@ MESSAGE_TYPE_LENGTH = const(2) REPORT_LENGTH = const(64) MAX_PAYLOAD_LEN = const(60000) -_ACK_MESSAGE = 0x20 - class Channel(Context): def __init__(self, channel_cache: ChannelCache) -> None: @@ -163,21 +161,17 @@ class Channel(Context): log.debug( __name__, "Received message with an unexpected synchronization bit" ) - await self._sendAck(sync_bit) + await self._send_ack(sync_bit) raise ThpError("Received message with an unexpected synchronization bit") # 3: Send ACK in response - if __debug__: - log.debug( - __name__, - "Writing ACK message to a channel with id: %d, sync bit: %d", - self.channel_id, - sync_bit, - ) - await self._sendAck(sync_bit) + await self._send_ack(sync_bit) + THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit) - self._handle_valid_message(payload_length, message_length, ctrl_byte) + await self._handle_valid_message( + payload_length, message_length, ctrl_byte, sync_bit + ) print("end handle completed message") def _check_checksum(self, message_length: int): @@ -189,15 +183,16 @@ class Channel(Context): self._todo_clear_buffer() raise ThpError("Invalid checksum, ignoring message.") - def _handle_valid_message( - self, payload_length: int, message_length: int, ctrl_byte: int + async def _handle_valid_message( + self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int ) -> None: state = self.get_channel_state() if __debug__: log.debug(__name__, _state_to_str(state)) if state is ChannelState.TH1: - self._handle_state_TH1(payload_length, message_length) + await self._handle_state_TH1(payload_length, message_length, sync_bit) + return if not _is_ctrl_byte_encrypted_transport(ctrl_byte): self._todo_clear_buffer() @@ -205,11 +200,15 @@ class Channel(Context): if state is ChannelState.ENCRYPTED_TRANSPORT: self._handle_state_ENCRYPTED_TRANSPORT(message_length) + return if state is ChannelState.TH2: - self._handle_state_TH2(message_length) + await self._handle_state_TH2(message_length, sync_bit) + return - def _handle_state_TH1(self, payload_length: int, message_length: int) -> None: + async def _handle_state_TH1( + self, payload_length: int, message_length: int, sync_bit: int + ) -> None: if not _is_ctrl_byte_handshake_init: raise ThpError("Message received is not a handshake init request!") if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH: @@ -218,8 +217,10 @@ class Channel(Context): self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH] ) cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key) - # TODO send ack in response - # TODO send handshake init response message + + await self._send_ack(sync_bit) + + # send handshake init response message loop.schedule( self._write_encrypted_payload_loop( thp_messages.get_handshake_init_response() @@ -228,7 +229,7 @@ class Channel(Context): self.set_channel_state(ChannelState.TH2) return - def _handle_state_TH2(self, message_length: int) -> None: + async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None: print("th2 branche") host_encrypted_static_pubkey = self.buffer[ INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH @@ -243,11 +244,11 @@ class Channel(Context): host_encrypted_static_pubkey, handshake_completion_request_noise_payload, ) # TODO remove - # TODO send ack in response - # TODO send hanshake completion response + + # send hanshake completion response loop.schedule( self._write_encrypted_payload_loop( - thp_messages.get_handshake_init_response() + thp_messages.get_handshake_completion_response() ) ) self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) @@ -288,6 +289,11 @@ class Channel(Context): self.create_new_session(message.passphrase) else: self.create_new_session() + # TODO reuse existing buffer and compute size dynamically + bufferrone = bytearray(2) + message_type: int = thp_messages.get_new_session_message(bufferrone) + + loop.schedule(self._write_encrypted_payload_loop(bufferrone)) except Exception as e: print("Proč??") print(e) @@ -311,13 +317,20 @@ class Channel(Context): self.expected_payload_length = 0 self.is_cont_packet_expected = False - async def _sendAck(self, ack_bit: int) -> None: - ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit) + async def _send_ack(self, ack_bit: int) -> None: + ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit) header = InitHeader( ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH ) chksum = checksum.compute(header.to_bytes()) - await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH) + if __debug__: + log.debug( + __name__, + "Writing ACK message to a channel with id: %d, sync bit: %d", + self.channel_id, + ack_bit, + ) + await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH) def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit): if sync_bit == 0: @@ -335,19 +348,22 @@ class Channel(Context): # trezor.crypto.noise.encode(key, payload=self.buffer) - # TODO payload_len should be output from trezor.crypto.noise.encode + # TODO payload_len should be output from trezor.crypto.noise.encode, I guess payload_len = noise_payload_len # + TAG_LENGTH # TODO loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len])) async def _write_encrypted_payload_loop(self, payload: bytes) -> None: print("write loop before while") - payload_len = len(payload) + payload_len = len(payload) + CHECKSUM_LENGTH sync_bit = THP.sync_get_send_bit(self.channel_cache) ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit) header = InitHeader( ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len ) + chksum = checksum.compute(header.to_bytes() + payload) + payload = payload + chksum + # TODO add condition that disallows to write when can_send_message is false THP.sync_set_can_send_message(self.channel_cache, False) while True: @@ -357,7 +373,7 @@ class Channel(Context): " send_sync_bit:", THP.sync_get_send_bit(self.channel_cache), ) - await self._write_encrypted_payload(header, payload, payload_len) + await self._write_payload_to_wire(header, payload, payload_len) self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack()) try: await self.waiting_for_ack_timeout @@ -365,28 +381,29 @@ class Channel(Context): THP.sync_set_send_bit_to_opposite(self.channel_cache) break - async def _write_encrypted_payload( + async def _write_payload_to_wire( self, header: InitHeader, payload: bytes, payload_len: int ): - + print("write payload to wire:") # prepare the report buffer with header data report = bytearray(REPORT_LENGTH) header.pack_to_buffer(report) # write initial report nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0) - await self._write_report(report) + await self._write_report_to_wire(report) # if we have more data to write, use continuation reports for it if nwritten < payload_len: header.pack_to_cont_buffer(report) while nwritten < payload_len: nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten) - await self._write_report(report) + await self._write_report_to_wire(report) - async def _write_report(self, report: utils.BufferType) -> None: + async def _write_report_to_wire(self, report: utils.BufferType) -> None: while True: await loop.wait(self.iface.iface_num() | io.POLL_WRITE) + printBytes(report) # TODO remove n = self.iface.write(report) if n == len(report): return diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index 731556150..a06706059 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -2,7 +2,7 @@ import ustruct # pyright:ignore[reportMissingModuleSource] from storage.cache_thp import BROADCAST_CHANNEL_ID from trezor import protobuf -from trezor.messages import ThpCreateNewSession +from trezor.messages import ThpCreateNewSession, ThpNewSession from .. import message_handler from ..protocol_common import Message @@ -15,6 +15,9 @@ ACK_MESSAGE = 0x20 _ERROR = 0x41 _CHANNEL_ALLOCATION_RES = 0x40 +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + class InitHeader: format_str = ">BHH" @@ -79,7 +82,21 @@ def get_error_unallocated_channel() -> bytes: def get_handshake_init_response() -> bytes: - return b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" # TODO implement + # TODO implement - 32 bytes ephemeral key, 48 bytes encrypted and masked public key, 16 bytes ciphertext of empty string (i.e. noise tag) + return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x40\x41\x42\x43\x44\x45\x46\x47\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15" + + +def get_handshake_completion_response() -> bytes: + return ( + TREZOR_STATE_PAIRED + + b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15" + ) + + +def get_new_session_message(buffer: bytearray) -> int: + msg = ThpNewSession(new_session_id=1) + encoded_msg = protobuf.encode(buffer, msg) + return encoded_msg def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType: