From f8f2bfa5358c1fdc8493c311ebcdfb8c89c512a6 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Thu, 30 Jan 2025 19:41:44 +0100 Subject: [PATCH] refactor(python): improve protocolV2 and related tests [no changelog] --- .../trezorlib/transport/thp/protocol_v2.py | 87 ++++----- tests/device_tests/thp/test_thp.py | 174 +++++++----------- 2 files changed, 98 insertions(+), 163 deletions(-) diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index 951065f42b..1604ae64ea 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -11,7 +11,7 @@ from enum import IntEnum import click from cryptography.hazmat.primitives.ciphers.aead import AESGCM -from ... import exceptions, messages +from ... import exceptions, messages, protobuf from ...mapping import ProtobufMapping from .. import Transport from ..thp import checksum, curve25519, thp_io @@ -28,6 +28,7 @@ MANAGEMENT_SESSION_ID: int = 0 if t.TYPE_CHECKING: from ...debuglink import DebugLink +MT = t.TypeVar("MT", bound=protobuf.MessageType) def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: @@ -135,20 +136,31 @@ class ProtocolV2(ProtocolAndChannel): raise exceptions.TrezorException("Unexpected response to GetFeatures") self._features = features + def _send_message( + self, + message: protobuf.MessageType, + session_id: int = MANAGEMENT_SESSION_ID, + ): + message_type, message_data = self.mapping.encode(message) + self._encrypt_and_write(session_id, message_type, message_data) + self._read_ack() + + def _read_message(self, message_type: type[MT]) -> MT: + _, msg_type, msg_data = self.read_and_decrypt() + msg = self.mapping.decode(msg_type, msg_data) + assert isinstance(msg, message_type) + return msg + def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None: + self._reset_sync_bits() + self._do_channel_allocation() + self._do_handshake() + self._do_pairing(helper_debug) + + def _reset_sync_bits(self) -> None: self.sync_bit_send = 0 self.sync_bit_receive = 0 - # Generate ephemeral keys - host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) - host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) - - self._do_channel_allocation() - - self._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey) - - self._do_pairing(helper_debug) - def _do_channel_allocation(self) -> None: channel_allocation_nonce = os.urandom(8) self._send_channel_allocation_request(channel_allocation_nonce) @@ -176,9 +188,10 @@ class ProtocolV2(ProtocolAndChannel): device_properties = payload[10:] return (channel_id, device_properties) - def _do_handshake( - self, host_ephemeral_privkey: bytes, host_ephemeral_pubkey: bytes - ): + def _do_handshake(self): + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + self._send_handshake_init_request(host_ephemeral_pubkey) self._read_ack() init_response = self._read_handshake_init_response() @@ -309,49 +322,21 @@ class ProtocolV2(ProtocolAndChannel): self._send_ack_1() def _do_pairing(self, helper_debug: DebugLink | None): - # Send StartPairingReqest message - message = messages.ThpPairingRequest() - message_type, message_data = self.mapping.encode(message) - self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - - # Read ACK - self._read_ack() - - # Read button request - _, msg_type, msg_data = self.read_and_decrypt() - maaa = self.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, messages.ButtonRequest) - - # Send button ACK - message = messages.ButtonAck() - message_type, message_data = self.mapping.encode(message) - - self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - self._read_ack() + self._send_message(messages.ThpPairingRequest()) + self._read_message(messages.ButtonRequest) + self._send_message(messages.ButtonAck()) if helper_debug is not None: helper_debug.press_yes() - # Read PairingRequestApproved - _, msg_type, msg_data = self.read_and_decrypt() - maaa = self.mapping.decode(msg_type, msg_data) - - assert isinstance(maaa, messages.ThpPairingRequestApproved) - - message = messages.ThpSelectMethod( - selected_pairing_method=messages.ThpPairingMethod.SkipPairing + self._read_message(messages.ThpPairingRequestApproved) + self._send_message( + messages.ThpSelectMethod( + selected_pairing_method=messages.ThpPairingMethod.SkipPairing + ) ) - message_type, message_data = self.mapping.encode(message) - - self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) - # Read ACK - self._read_ack() - - # Read ThpEndResponse - _, msg_type, msg_data = self.read_and_decrypt() - maaa = self.mapping.decode(msg_type, msg_data) - assert isinstance(maaa, messages.ThpEndResponse) + self._read_message(messages.ThpEndResponse) self._has_valid_channel = True diff --git a/tests/device_tests/thp/test_thp.py b/tests/device_tests/thp/test_thp.py index 72bfe3a75c..6b4569ca5d 100644 --- a/tests/device_tests/thp/test_thp.py +++ b/tests/device_tests/thp/test_thp.py @@ -31,7 +31,7 @@ from trezorlib.messages import ( ) from trezorlib.transport.thp import curve25519 from trezorlib.transport.thp.cpace import Cpace -from trezorlib.transport.thp.protocol_v2 import MANAGEMENT_SESSION_ID, _hkdf +from trezorlib.transport.thp.protocol_v2 import _hkdf if t.TYPE_CHECKING: P = tx.ParamSpec("P") @@ -41,22 +41,24 @@ MT = t.TypeVar("MT", bound=protobuf.MessageType) pytestmark = [pytest.mark.protocol("protocol_v2")] -protocol: ProtocolV2 - - -def _prepare_protocol(client: Client): - global protocol +def _prepare_protocol(client: Client) -> ProtocolV2: protocol = client.protocol - protocol.sync_bit_send = 0 - protocol.sync_bit_receive = 0 + assert isinstance(protocol, ProtocolV2) + protocol._reset_sync_bits() + return protocol + + +def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2: + protocol = _prepare_protocol(client) + protocol._do_channel_allocation() + protocol._do_handshake() + return protocol def test_allocate_channel(client: Client) -> None: - global protocol - _prepare_protocol(client) + protocol = _prepare_protocol(client) - # protocol: ProtocolV2 = client.protocol - nonce = b"\x1A\x2B\x3B\x4A\x5C\x6D\x7E\x8F" + nonce = random.randbytes(8) # Use valid nonce protocol._send_channel_allocation_request(nonce) @@ -72,9 +74,7 @@ def test_allocate_channel(client: Client) -> None: def test_handshake(client: Client) -> None: - global protocol - _prepare_protocol(client) - # protocol: ProtocolV2 = client.protocol + protocol = _prepare_protocol(client) host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) @@ -110,53 +110,24 @@ def test_handshake(client: Client) -> None: assert noise_tag is not None -def _send_message( - message: MT, - session_id: int = MANAGEMENT_SESSION_ID, -): - global protocol - message_type, message_data = protocol.mapping.encode(message) - protocol._encrypt_and_write(session_id, message_type, message_data) - protocol._read_ack() - - -def _read_message(message_type: type[MT]) -> MT: - global protocol - _, msg_type, msg_data = protocol.read_and_decrypt() - msg = protocol.mapping.decode(msg_type, msg_data) - assert isinstance(msg, message_type) - return msg - - def test_pairing_qr_code(client: Client) -> None: - global protocol - _prepare_protocol(client) + protocol = _prepare_protocol_for_pairing(client) - # Generate ephemeral keys - host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) - host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) - - protocol._do_channel_allocation() - - protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey) - - _send_message(ThpPairingRequest()) - - _read_message(ButtonRequest) - - _send_message(ButtonAck()) + protocol._send_message(ThpPairingRequest()) + protocol._read_message(ButtonRequest) + protocol._send_message(ButtonAck()) client.debug.press_yes() - _read_message(ThpPairingRequestApproved) - - _send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)) - - _read_message(ThpPairingPreparationsFinished) + protocol._read_message(ThpPairingRequestApproved) + protocol._send_message( + ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode) + ) + protocol._read_message(ThpPairingPreparationsFinished) # QR Code shown - _read_message(ButtonRequest) - _send_message(ButtonAck()) + protocol._read_message(ButtonRequest) + protocol._send_message(ButtonAck()) # Read code from "Trezor's display" using debuglink @@ -170,9 +141,9 @@ def test_pairing_qr_code(client: Client) -> None: sha_ctx.update(code) tag = sha_ctx.digest() - _send_message(ThpQrCodeTag(tag=tag)) + protocol._send_message(ThpQrCodeTag(tag=tag)) - secret_msg = _read_message(ThpQrCodeSecret) + secret_msg = protocol._read_message(ThpQrCodeSecret) # Check that the `code` was derived from the revealed secret sha_ctx = sha256(ThpPairingMethod.QrCode.to_bytes(1, "big")) @@ -181,48 +152,38 @@ def test_pairing_qr_code(client: Client) -> None: computed_code = sha_ctx.digest()[:16] assert code == computed_code - _send_message(ThpEndRequest()) - _read_message(ThpEndResponse) + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) protocol._has_valid_channel = True def test_pairing_code_entry(client: Client) -> None: - global protocol - _prepare_protocol(client) + protocol = _prepare_protocol_for_pairing(client) - # Generate ephemeral keys - host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) - host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) - - protocol._do_channel_allocation() - - protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey) - - _send_message(ThpPairingRequest()) - - _read_message(ButtonRequest) - - _send_message(ButtonAck()) + protocol._send_message(ThpPairingRequest()) + protocol._read_message(ButtonRequest) + protocol._send_message(ButtonAck()) client.debug.press_yes() - _read_message(ThpPairingRequestApproved) + protocol._read_message(ThpPairingRequestApproved) + protocol._send_message( + ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry) + ) - _send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)) - - commitment_msg = _read_message(ThpCodeEntryCommitment) + commitment_msg = protocol._read_message(ThpCodeEntryCommitment) commitment = commitment_msg.commitment challenge = random.randbytes(16) - _send_message(ThpCodeEntryChallenge(challenge=challenge)) + protocol._send_message(ThpCodeEntryChallenge(challenge=challenge)) - cpace_trezor = _read_message(ThpCodeEntryCpaceTrezor) + cpace_trezor = protocol._read_message(ThpCodeEntryCpaceTrezor) cpace_trezor_public_key = cpace_trezor.cpace_trezor_public_key # Code Entry code shown - _read_message(ButtonRequest) - _send_message(ButtonAck()) + protocol._read_message(ButtonRequest) + protocol._send_message(ButtonAck()) pairing_info = client.debug.pairing_info( thp_channel_id=protocol.channel_id.to_bytes(2, "big") @@ -235,14 +196,14 @@ def test_pairing_code_entry(client: Client) -> None: sha_ctx = sha256(cpace.shared_secret) tag = sha_ctx.digest() - _send_message( + protocol._send_message( ThpCodeEntryCpaceHostTag( cpace_host_public_key=cpace.host_public_key, tag=tag, ) ) - secret_msg = _read_message(ThpCodeEntrySecret) + secret_msg = protocol._read_message(ThpCodeEntrySecret) # Check `commitment` and `code` sha_ctx = sha256(secret_msg.secret) @@ -257,41 +218,30 @@ def test_pairing_code_entry(client: Client) -> None: computed_code = int.from_bytes(code_hash, "big") % 1000000 assert code == computed_code - _send_message(ThpEndRequest()) - _read_message(ThpEndResponse) + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) protocol._has_valid_channel = True def test_pairing_nfc(client: Client) -> None: - global protocol - _prepare_protocol(client) + protocol = _prepare_protocol_for_pairing(client) - # Generate ephemeral keys - host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) - host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) - - protocol._do_channel_allocation() - - protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey) - - _send_message(ThpPairingRequest()) - - _read_message(ButtonRequest) - - _send_message(ButtonAck()) + protocol._send_message(ThpPairingRequest()) + protocol._read_message(ButtonRequest) + protocol._send_message(ButtonAck()) client.debug.press_yes() - _read_message(ThpPairingRequestApproved) - - _send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC)) - - _read_message(ThpPairingPreparationsFinished) + protocol._read_message(ThpPairingRequestApproved) + protocol._send_message( + ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC) + ) + protocol._read_message(ThpPairingPreparationsFinished) # NFC screen shown - _read_message(ButtonRequest) - _send_message(ButtonAck()) + protocol._read_message(ButtonRequest) + protocol._send_message(ButtonAck()) nfc_secret_host = random.randbytes(16) # Read `nfc_secret` and `handshake_hash` from Trezor using debuglink @@ -311,9 +261,9 @@ def test_pairing_nfc(client: Client) -> None: sha_ctx.update(nfc_secret_trezor) tag_host = sha_ctx.digest() - _send_message(ThpNfcTagHost(tag=tag_host)) + protocol._send_message(ThpNfcTagHost(tag=tag_host)) - tag_trezor_msg = _read_message(ThpNfcTagTrezor) + tag_trezor_msg = protocol._read_message(ThpNfcTagTrezor) # Check that the `code` was derived from the revealed secret sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big")) @@ -322,7 +272,7 @@ def test_pairing_nfc(client: Client) -> None: computed_tag = sha_ctx.digest() assert tag_trezor_msg.tag == computed_tag - _send_message(ThpEndRequest()) - _read_message(ThpEndResponse) + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) protocol._has_valid_channel = True