From 9757a18a675f0a07d17e1935d2d72ae8b80440dd Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Sat, 25 Jan 2025 22:47:30 +0100 Subject: [PATCH] refactor(python): improve structure of protocol_v2 channel establishment [no changelog] --- .../trezorlib/transport/thp/protocol_v2.py | 115 ++++++++++++------ 1 file changed, 76 insertions(+), 39 deletions(-) diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index fc640dfd9e..945dde6893 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -132,40 +132,78 @@ class ProtocolV2(ProtocolAndChannel): def _establish_new_channel(self) -> None: self.sync_bit_send = 0 self.sync_bit_receive = 0 - # Send channel allocation request - channel_id_request_nonce = os.urandom(8) + + # Generate ephemeral keys + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + + self._do_channel_allocation() + + self._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey) + + self._do_pairing() + + def _do_channel_allocation(self) -> None: + channel_allocation_nonce = os.urandom(8) + self._send_channel_allocation_request(channel_allocation_nonce) + cid, dp = self._read_channel_allocation_response(channel_allocation_nonce) + self.channel_id = cid + self.device_properties = dp + + def _send_channel_allocation_request(self, nonce: bytes): thp_io.write_payload_to_wire_and_add_checksum( self.transport, MessageHeader.get_channel_allocation_request_header(12), - channel_id_request_nonce, + nonce, ) - # Read channel allocation response + def _read_channel_allocation_response( + self, expected_nonce: bytes + ) -> tuple[int, bytes]: header, payload = self._read_until_valid_crc_check() if not self._is_valid_channel_allocation_response( - header, payload, channel_id_request_nonce + header, payload, expected_nonce ): - # TODO raise exception here, I guess raise Exception("Invalid channel allocation response.") - self.channel_id = int.from_bytes(payload[8:10], "big") - self.device_properties = payload[10:] + channel_id = int.from_bytes(payload[8:10], "big") + device_properties = payload[10:] + return (channel_id, device_properties) - # Send handshake init request + def _do_handshake( + self, host_ephemeral_privkey: bytes, host_ephemeral_pubkey: bytes + ): + self._send_handshake_init_request(host_ephemeral_pubkey) + self._read_ack() + init_response = self._read_handshake_init_response() + + trezor_ephemeral_pubkey = init_response[:32] + encrypted_trezor_static_pubkey = init_response[32:80] + noise_tag = init_response[80:96] + LOG.debug("noise_tag: %s", hexlify(noise_tag).decode()) + + # TODO check noise_tag is valid + + ck = self._send_handshake_completion_request( + host_ephemeral_pubkey, + host_ephemeral_privkey, + trezor_ephemeral_pubkey, + encrypted_trezor_static_pubkey, + ) + self._read_ack() + self._read_handshake_completion_response() + self.key_request, self.key_response = _hkdf(ck, b"") + self.nonce_request = 0 + self.nonce_response = 1 + + def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None: ha_init_req_header = MessageHeader(0, self.channel_id, 36) - host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) - host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) thp_io.write_payload_to_wire_and_add_checksum( self.transport, ha_init_req_header, host_ephemeral_pubkey ) - # Read ACK - header, payload = self._read_until_valid_crc_check() - if not header.is_ack() or len(payload) > 0: - click.echo("Received message is not a valid ACK", err=True) - - # Read handshake init response + def _read_handshake_init_response(self) -> bytes: header, payload = self._read_until_valid_crc_check() self._send_ack_0() @@ -180,15 +218,16 @@ class ProtocolV2(ProtocolAndChannel): "Received message is not a valid handshake init response message", err=True, ) + return payload - trezor_ephemeral_pubkey = payload[:32] - encrypted_trezor_static_pubkey = payload[32:80] - noise_tag = payload[80:96] - - # TODO check noise tag - LOG.debug("noise_tag: %s", hexlify(noise_tag).decode()) - - # Prepare and send handshake completion request + def _send_handshake_completion_request( + self, + host_ephemeral_pubkey: bytes, + host_ephemeral_privkey: bytes, + trezor_ephemeral_pubkey: bytes, + encrypted_trezor_static_pubkey: bytes, + credential: bytes | None = None, + ) -> bytes: PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" @@ -232,7 +271,8 @@ class ProtocolV2(ProtocolAndChannel): messages.ThpHandshakeCompletionReqNoisePayload( pairing_methods=[ messages.ThpPairingMethod.NoMethod, - ] + ], + host_pairing_credential=credential, ) ) @@ -252,12 +292,9 @@ class ProtocolV2(ProtocolAndChannel): ha_completion_req_header, encrypted_host_static_pubkey + encrypted_payload, ) + return ck - # Read ACK - header, payload = self._read_until_valid_crc_check() - if not header.is_ack() or len(payload) > 0: - click.echo("Received message is not a valid ACK", err=True) - + def _read_handshake_completion_response(self) -> None: # Read handshake completion response, ignore payload as we do not care about the state header, _ = self._read_until_valid_crc_check() if not header.is_handshake_comp_response(): @@ -267,10 +304,7 @@ class ProtocolV2(ProtocolAndChannel): ) self._send_ack_1() - self.key_request, self.key_response = _hkdf(ck, b"") - self.nonce_request = 0 - self.nonce_response = 1 - + def _do_pairing(self): # Send StartPairingReqest message message = messages.ThpStartPairingRequest() message_type, message_data = self.mapping.encode(message) @@ -278,17 +312,20 @@ class ProtocolV2(ProtocolAndChannel): self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) # Read ACK - header, payload = self._read_until_valid_crc_check() - if not header.is_ack() or len(payload) > 0: - click.echo("Received message is not a valid ACK", err=True) + self._read_ack() - # Read + # Read ThpEndResponse _, msg_type, msg_data = self.read_and_decrypt() maaa = self.mapping.decode(msg_type, msg_data) assert isinstance(maaa, messages.ThpEndResponse) self._has_valid_channel = True + def _read_ack(self): + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + click.echo("Received message is not a valid ACK", err=True) + def _send_ack_0(self): LOG.debug("sending ack 0") header = MessageHeader(0x20, self.channel_id, 4)