From 42820e2f9ec11c8a34fd91163d064dadb486c34e Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 14 Mar 2025 20:34:07 +0100 Subject: [PATCH] wip --- core/src/trezor/wire/__init__.py | 5 +- core/src/trezor/wire/thp/channel.py | 18 +- tests/device_tests/thp/connect.py | 48 +++ tests/device_tests/thp/test_handshake.py | 51 +++ tests/device_tests/thp/test_multiple_apps.py | 39 +++ tests/device_tests/thp/test_pairing.py | 338 +++++++++++++++++++ 6 files changed, 494 insertions(+), 5 deletions(-) create mode 100644 tests/device_tests/thp/connect.py create mode 100644 tests/device_tests/thp/test_handshake.py create mode 100644 tests/device_tests/thp/test_multiple_apps.py create mode 100644 tests/device_tests/thp/test_pairing.py diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 287ab3377b..977d2530dd 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -55,6 +55,7 @@ if TYPE_CHECKING: def setup(iface: WireInterface) -> None: """Initialize the wire stack on the provided WireInterface.""" + print(f"SETUP - handle_session on iface {iface}") loop.schedule(handle_session(iface)) @@ -80,8 +81,8 @@ if utils.USE_THP: # Unload modules imported by the workflow. Should not raise. if __debug__: log.debug(__name__, "utils.unimport_end(modules) and loop.clear()") - utils.unimport_end(modules) - loop.clear() + # utils.unimport_end(modules) + # loop.clear() return # pylint: disable=lost-exception else: diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 82664fe189..fd8762f1a1 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -78,6 +78,8 @@ class Channel: # Objects for writing a message to a wire self.transmission_loop: TransmissionLoop | None = None + print("RESETTING TRANSMISSION LOOP", self.transmission_loop) + print("CHANNEL MEM_ADDRESS:", self) self.write_task_spawn: loop.spawn | None = None # Temporary objects @@ -140,7 +142,12 @@ class Channel: self._log("receive packet") self._handle_received_packet(packet) - + print( + "RECEIVE PACKET - DOES TRANSMISSION LOOP EXIST?", + self.transmission_loop is not None, + self.transmission_loop, + ) + print(self) try: buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) except WireBufferError: @@ -489,11 +496,16 @@ class Channel: sync_bit = ABP.get_send_seq_bit(self.channel_cache) ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit) header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len) + print("CREATING TRANSMISSION LOOP") self.transmission_loop = TransmissionLoop(self, header, payload) + print("\n\n>>> TRANSSMISSION LOOP START -before", self.transmission_loop) + print("CHANNEL:", self) + print("SELF.WRITE_TASK_SPAWN", self.write_task_spawn is not None) + print("SELF.RETRANSMISSION_LOOP", self.transmission_loop is not None) await self.transmission_loop.start() - + print("\n\n>>> TRANSSMISSION LOOP START -after", self.transmission_loop) ABP.set_send_seq_bit_to_opposite(self.channel_cache) - + print("\n\n>>>>> Set seq bit to opposite", self.transmission_loop) # Let the main loop be restarted and clear loop, if there is no other # workflow and the state is ENCRYPTED_TRANSPORT if self._can_clear_loop(): diff --git a/tests/device_tests/thp/connect.py b/tests/device_tests/thp/connect.py new file mode 100644 index 0000000000..ac2e81d6f9 --- /dev/null +++ b/tests/device_tests/thp/connect.py @@ -0,0 +1,48 @@ +from trezorlib.client import ProtocolV2Channel +from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.messages import ( + ButtonAck, + ButtonRequest, + ThpPairingRequest, + ThpPairingRequestApproved, +) + + +def prepare_protocol_for_handshake(client: Client) -> ProtocolV2Channel: + protocol = client.protocol + assert isinstance(protocol, ProtocolV2Channel) + protocol._reset_sync_bits() + protocol._do_channel_allocation() + return protocol + + +def prepare_protocol_for_pairing( + client: Client, host_static_randomness: bytes | None = None +) -> ProtocolV2Channel: + protocol = prepare_protocol_for_handshake(client) + protocol._do_handshake(host_static_randomness=host_static_randomness) + return protocol + + +def get_encrypted_transport_protocol( + client: Client, host_static_randomness: bytes | None = None +) -> ProtocolV2Channel: + protocol = prepare_protocol_for_pairing( + client, host_static_randomness=host_static_randomness + ) + protocol._do_pairing(client.debug) + return protocol + + +def handle_pairing_request( + client: Client, protocol: ProtocolV2Channel, host_name: str | None = None +) -> None: + protocol._send_message(ThpPairingRequest(host_name=host_name)) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "pairing_request" + + protocol._send_message(ButtonAck()) + + client.debug.press_yes() + + protocol._read_message(ThpPairingRequestApproved) diff --git a/tests/device_tests/thp/test_handshake.py b/tests/device_tests/thp/test_handshake.py new file mode 100644 index 0000000000..b773a7bc5c --- /dev/null +++ b/tests/device_tests/thp/test_handshake.py @@ -0,0 +1,51 @@ +import os +import pytest + +from trezorlib.client import ProtocolV2Channel +from trezorlib.debuglink import TrezorClientDebugLink as Client + + +from .connect import prepare_protocol_for_handshake + + +pytestmark = [pytest.mark.protocol("protocol_v2")] + + +def test_allocate_channel(client: Client) -> None: + assert isinstance(client.protocol, ProtocolV2Channel) + + nonce = os.urandom(8) + + # Use valid nonce + client.protocol._send_channel_allocation_request(nonce) + client.protocol._read_channel_allocation_response(nonce) + + # Expect different nonce + client.protocol._send_channel_allocation_request(nonce) + with pytest.raises(Exception, match="Invalid channel allocation response."): + client.protocol._read_channel_allocation_response( + expected_nonce=b"\xde\xad\xbe\xef\xde\xad\xbe\xef" + ) + client.invalidate() + + +def test_handshake(client: Client) -> None: + protocol = prepare_protocol_for_handshake(client) + + randomness_static = os.urandom(32) + + protocol._do_channel_allocation() + protocol._init_noise( + randomness_static=randomness_static, + ) + protocol._send_handshake_init_request() + protocol._read_ack() + protocol._read_handshake_init_response() + + protocol._send_handshake_completion_request() + protocol._read_ack() + protocol._read_handshake_completion_response() + + # TODO - without pairing, the client is damaged and results in fail of the following test + # so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error + protocol._do_pairing(client.debug) diff --git a/tests/device_tests/thp/test_multiple_apps.py b/tests/device_tests/thp/test_multiple_apps.py new file mode 100644 index 0000000000..adea18a2b9 --- /dev/null +++ b/tests/device_tests/thp/test_multiple_apps.py @@ -0,0 +1,39 @@ +import os +import time +import pytest +from trezorlib.client import ProtocolV2Channel + +from trezorlib.debuglink import TrezorClientDebugLink as Client + +pytestmark = [pytest.mark.protocol("protocol_v2")] + + +def test_multiple_hosts(client: Client) -> None: + assert isinstance(client.protocol, ProtocolV2Channel) + protocol_1 = client.protocol + protocol_2 = ProtocolV2Channel(protocol_1.transport, protocol_1.mapping) + + nonce_1 = os.urandom(8) + nonce_2 = os.urandom(8) + if nonce_1 == nonce_2: + nonce_2 = (int.from_bytes(nonce_1) + 1).to_bytes(8, "big") + protocol_1._send_channel_allocation_request(nonce_1) + protocol_1.channel_id, protocol_1.device_properties = ( + protocol_1._read_channel_allocation_response(nonce_1) + ) + protocol_2._send_channel_allocation_request(nonce_2) + protocol_2.channel_id, protocol_2.device_properties = ( + protocol_2._read_channel_allocation_response(nonce_2) + ) + + protocol_1._init_noise() + protocol_2._init_noise() + + protocol_1._send_handshake_init_request() + protocol_1._read_ack() + protocol_1._read_handshake_init_response() + + time.sleep(0.2) # To pass LOCK_TIME + protocol_2._send_handshake_init_request() + protocol_2._read_ack() + # protocol_2._read_handshake_init_response() diff --git a/tests/device_tests/thp/test_pairing.py b/tests/device_tests/thp/test_pairing.py new file mode 100644 index 0000000000..725c6908d4 --- /dev/null +++ b/tests/device_tests/thp/test_pairing.py @@ -0,0 +1,338 @@ +import os +import typing as t +from hashlib import sha256 + +import pytest +import typing_extensions as tx + +from tests.common import get_test_address +from trezorlib import exceptions, protobuf +from trezorlib.client import ProtocolV2Channel +from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.messages import ( + ButtonAck, + ButtonRequest, + ThpCodeEntryChallenge, + ThpCodeEntryCommitment, + ThpCodeEntryCpaceHostTag, + ThpCodeEntryCpaceTrezor, + ThpCodeEntrySecret, + ThpCredentialRequest, + ThpCredentialResponse, + ThpEndRequest, + ThpEndResponse, + ThpNfcTagHost, + ThpNfcTagTrezor, + ThpPairingMethod, + ThpPairingPreparationsFinished, + ThpQrCodeSecret, + ThpQrCodeTag, + ThpSelectMethod, +) +from trezorlib.transport.thp import curve25519 +from trezorlib.transport.thp.cpace import Cpace + +from .connect import ( + prepare_protocol_for_handshake, + prepare_protocol_for_pairing, + get_encrypted_transport_protocol, + handle_pairing_request, +) + +if t.TYPE_CHECKING: + P = tx.ParamSpec("P") + +MT = t.TypeVar("MT", bound=protobuf.MessageType) + +pytestmark = [pytest.mark.protocol("protocol_v2")] + + +def test_pairing_qr_code(client: Client) -> None: + protocol = prepare_protocol_for_pairing(client) + handle_pairing_request(client, protocol, "TestTrezor QrCode") + protocol._send_message( + ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode) + ) + protocol._read_message(ThpPairingPreparationsFinished) + + # QR Code shown + + # Read code from "Trezor's display" using debuglink + + pairing_info = client.debug.pairing_info( + thp_channel_id=protocol.channel_id.to_bytes(2, "big") + ) + code = pairing_info.code_qr_code + + # Compute tag for response + sha_ctx = sha256(protocol.handshake_hash) + sha_ctx.update(code) + tag = sha_ctx.digest() + + protocol._send_message(ThpQrCodeTag(tag=tag)) + + secret_msg = protocol._read_message(ThpQrCodeSecret) + + # Check that the `code` was derived from the revealed secret + sha_ctx = sha256(ThpPairingMethod.QrCode.to_bytes(1, "big")) + sha_ctx.update(protocol.handshake_hash) + sha_ctx.update(secret_msg.secret) + computed_code = sha_ctx.digest()[:16] + assert code == computed_code + + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + protocol._has_valid_channel = True + + +def test_pairing_code_entry(client: Client) -> None: + protocol = prepare_protocol_for_pairing(client) + + handle_pairing_request(client, protocol, "TestTrezor CodeEntry") + + protocol._send_message( + ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry) + ) + + commitment_msg = protocol._read_message(ThpCodeEntryCommitment) + commitment = commitment_msg.commitment + + challenge = os.urandom(16) + protocol._send_message(ThpCodeEntryChallenge(challenge=challenge)) + + cpace_trezor = protocol._read_message(ThpCodeEntryCpaceTrezor) + cpace_trezor_public_key = cpace_trezor.cpace_trezor_public_key + + # Code Entry code shown + + pairing_info = client.debug.pairing_info( + thp_channel_id=protocol.channel_id.to_bytes(2, "big") + ) + code = pairing_info.code_entry_code + + cpace = Cpace(handshake_hash=protocol.handshake_hash) + cpace.random_bytes = os.urandom + cpace.generate_keys_and_secret(code.to_bytes(6, "big"), cpace_trezor_public_key) + sha_ctx = sha256(cpace.shared_secret) + tag = sha_ctx.digest() + + protocol._send_message( + ThpCodeEntryCpaceHostTag( + cpace_host_public_key=cpace.host_public_key, + tag=tag, + ) + ) + + secret_msg = protocol._read_message(ThpCodeEntrySecret) + + # Check `commitment` and `code` + sha_ctx = sha256(secret_msg.secret) + computed_commitment = sha_ctx.digest() + assert commitment == computed_commitment + + sha_ctx = sha256(ThpPairingMethod.CodeEntry.to_bytes(1, "big")) + sha_ctx.update(protocol.handshake_hash) + sha_ctx.update(secret_msg.secret) + sha_ctx.update(challenge) + code_hash = sha_ctx.digest() + computed_code = int.from_bytes(code_hash, "big") % 1000000 + assert code == computed_code + + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + protocol._has_valid_channel = True + + +def test_pairing_nfc(client: Client) -> None: + protocol = prepare_protocol_for_pairing(client) + + _nfc_pairing(client, protocol) + + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + protocol._has_valid_channel = True + + +def _nfc_pairing(client: Client, protocol: ProtocolV2Channel) -> None: + + handle_pairing_request(client, protocol, "TestTrezor NfcPairing") + + protocol._send_message( + ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC) + ) + protocol._read_message(ThpPairingPreparationsFinished) + + # NFC screen shown + + nfc_secret_host = os.urandom(16) + # Read `nfc_secret` and `handshake_hash` from Trezor using debuglink + pairing_info = client.debug.pairing_info( + thp_channel_id=protocol.channel_id.to_bytes(2, "big"), + handshake_hash=protocol.handshake_hash, + nfc_secret_host=nfc_secret_host, + ) + handshake_hash_trezor = pairing_info.handshake_hash + nfc_secret_trezor = pairing_info.nfc_secret_trezor + + assert handshake_hash_trezor[:16] == protocol.handshake_hash[:16] + + # Compute tag for response + sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big")) + sha_ctx.update(protocol.handshake_hash) + sha_ctx.update(nfc_secret_trezor) + tag_host = sha_ctx.digest() + + protocol._send_message(ThpNfcTagHost(tag=tag_host)) + + tag_trezor_msg = protocol._read_message(ThpNfcTagTrezor) + + # Check that the `code` was derived from the revealed secret + sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big")) + sha_ctx.update(protocol.handshake_hash) + sha_ctx.update(nfc_secret_host) + computed_tag = sha_ctx.digest() + assert tag_trezor_msg.tag == computed_tag + + +def test_credential_phase(client: Client) -> None: + protocol = prepare_protocol_for_pairing(client) + _nfc_pairing(client, protocol) + + # Request credential with confirmation after pairing + randomness_static = os.urandom(32) + host_static_privkey = curve25519.get_private_key(randomness_static) + host_static_pubkey = curve25519.get_public_key(host_static_privkey) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False) + ) + credential_response = protocol._read_message(ThpCredentialResponse) + + assert credential_response.credential is not None + credential = credential_response.credential + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + # Connect using credential with confirmation + protocol = prepare_protocol_for_handshake(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential, randomness_static) + protocol._send_message(ThpEndRequest()) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "connection_request" + protocol._send_message(ButtonAck()) + client.debug.press_yes() + protocol._read_message(ThpEndResponse) + + # Delete channel from the device by sending badly encrypted message + # This is done to prevent channel replacement and trigerring of autoconnect false -> true + protocol._noise.noise_protocol.cipher_state_encrypt.n = 250 + + protocol._send_message(ButtonAck()) + with pytest.raises(Exception) as e: + protocol.read(1) + assert e.value.args[0] == "Received ThpError: DECRYPTION FAILED" + + # Connect using credential with confirmation and ask for autoconnect credential. + protocol = prepare_protocol_for_handshake(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential, randomness_static) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=True) + ) + # Connection confirmation dialog is shown. (Channel replacement is not triggered.) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "connection_request" + protocol._send_message(ButtonAck()) + client.debug.press_yes() + # Autoconnect issuance confirmation dialog is shown. + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "autoconnect_credential_request" + protocol._send_message(ButtonAck()) + client.debug.press_yes() + # Autoconnect credential is received + credential_response_2 = protocol._read_message(ThpCredentialResponse) + assert credential_response_2.credential is not None + credential_auto = credential_response_2.credential + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + # Connect using credential with confirmation + protocol = prepare_protocol_for_handshake(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential, randomness_static) + # Confirmation dialog is not shown as channel in ENCRYPTED TRANSPORT state with the same + # host static public key is still available in Trezor's cache. (Channel replacement is triggered.) + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + # Connect using autoconnect credential + protocol = prepare_protocol_for_handshake(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential_auto, randomness_static) + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + # Delete channel from the device by sending badly encrypted message + # This is done to prevent channel replacement and trigerring of autoconnect false -> true + protocol._noise.noise_protocol.cipher_state_encrypt.n = 100 + + protocol._send_message(ButtonAck()) + with pytest.raises(Exception) as e: + protocol.read(1) + assert e.value.args[0] == "Received ThpError: DECRYPTION FAILED" + + # Connect using autoconnect credential - should work the same as above + protocol = prepare_protocol_for_handshake(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential_auto, randomness_static) + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + +@pytest.mark.setup_client(passphrase=True) +def test_channel_replacement(client: Client) -> None: + assert client.features.passphrase_protection is True + + host_static_randomness = os.urandom(32) + host_static_randomness_2 = os.urandom(32) + host_static_privkey = curve25519.get_private_key(host_static_randomness) + host_static_privkey_2 = curve25519.get_private_key(host_static_randomness_2) + + assert host_static_privkey != host_static_privkey_2 + + client.protocol = get_encrypted_transport_protocol(client, host_static_randomness) + + session = client.get_session(passphrase="TREZOR", session_id=b"\x10") + address = get_test_address(session) + + session_2 = client.get_session(passphrase="ROZERT", session_id=b"\x20") + address_2 = get_test_address(session_2) + assert address != address_2 + + # create new channel using the same host_static_privkey + client.protocol = get_encrypted_transport_protocol(client, host_static_randomness) + session_3 = client.get_session(passphrase="OKIDOKI", session_id=b"\x30") + address_3 = get_test_address(session_3) + assert address_3 != address_2 + + # test address on regenerated channel + new_address = get_test_address(session) + assert address == new_address + new_address_3 = get_test_address(session_3) + assert address_3 == new_address_3 + + # create new channel using different host_static_privkey + client.protocol = get_encrypted_transport_protocol(client, host_static_randomness_2) + with pytest.raises(exceptions.TrezorFailure) as e_1: + _ = get_test_address(session) + assert str(e_1.value.message) == "Invalid session" + + with pytest.raises(exceptions.TrezorFailure) as e_2: + _ = get_test_address(session_3) + assert str(e_2.value.message) == "Invalid session" + + session_4 = client.get_session(passphrase="TREZOR", session_id=b"\x40") + super_new_address = get_test_address(session_4) + assert address == super_new_address