diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 6b076acbd1..4463f6b130 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from ubinascii import hexlify import trezorui_api -from trezor import loop, protobuf, workflow +from trezor import loop, protobuf, utils, workflow from trezor.enums import ButtonRequestType from trezor.wire import context, message_handler, protocol_common from trezor.wire.context import UnexpectedMessageException @@ -104,8 +104,12 @@ class PairingContext(Context): ) message: Message = await self.incoming_message - if message.type not in expected_types: + from trezor.messages import Cancel + + if message.type == Cancel.MESSAGE_WIRE_TYPE: + raise ActionCancelled + raise UnexpectedMessageException(message) if expected_type is None: @@ -192,18 +196,20 @@ class PairingContext(Context): from trezor.ui.layouts.common import interact if not device_name: - action_string = f"Allow {self.host_name} to pair with this Trezor?" + action_string = f"Allow {self.host_name} to connect with this Trezor?" else: action_string = ( - f"Allow {self.host_name} on {device_name} to pair with this Trezor?" + f"Allow {self.host_name} on {device_name} to connect with this Trezor?" ) - await interact( + result = await interact( trezorui_api.confirm_action( title="Connection dialog", action=action_string, description=None ), br_name="thp_connection_request", br_code=ButtonRequestType.Other, ) + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + log.debug(__name__, "Result of connection dialog %s", str(result)) async def show_autoconnect_credential_confirmation_screen(self) -> None: from trezor.ui.layouts.common import interact diff --git a/tests/device_tests/thp/connect.py b/tests/device_tests/thp/connect.py index 26678de128..7279370771 100644 --- a/tests/device_tests/thp/connect.py +++ b/tests/device_tests/thp/connect.py @@ -17,10 +17,14 @@ def prepare_protocol_for_handshake(client: Client) -> ProtocolV2Channel: def prepare_protocol_for_pairing( - client: Client, host_static_randomness: bytes | None = None + client: Client, + host_static_randomness: bytes | None = None, + credential: bytes | None = None, ) -> ProtocolV2Channel: protocol = prepare_protocol_for_handshake(client) - protocol._do_handshake(host_static_randomness=host_static_randomness) + protocol._do_handshake( + credential=credential, host_static_randomness=host_static_randomness + ) return protocol diff --git a/tests/device_tests/thp/test_pairing.py b/tests/device_tests/thp/test_pairing.py index 39bf7c0778..f00f59f95a 100644 --- a/tests/device_tests/thp/test_pairing.py +++ b/tests/device_tests/thp/test_pairing.py @@ -229,6 +229,85 @@ def _nfc_pairing(client: Client, protocol: ProtocolV2Channel) -> None: assert tag_trezor_msg.tag == computed_tag +def test_connection_confirmation_cancel(client: Client) -> None: + protocol = prepare_protocol_for_pairing(client) + _nfc_pairing(client, protocol) + + # Request credential with confirmation after pairing + randomness_static = os.urandom(32) + host_static_privkey = curve25519.get_private_key(randomness_static) + host_static_pubkey = curve25519.get_public_key(host_static_privkey) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False) + ) + credential_response = protocol._read_message(ThpCredentialResponse) + + assert credential_response.credential is not None + credential = credential_response.credential + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + # Connect using credential with confirmation + protocol = prepare_protocol_for_pairing( + client=client, host_static_randomness=randomness_static, credential=credential + ) + protocol._send_message(ThpEndRequest()) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "thp_connection_request" + protocol._send_message(Cancel()) + failure = protocol._read_message(Failure) + + assert failure.code == FailureType.ActionCancelled + + time.sleep(0.2) # TODO fix this behavior + protocol = prepare_protocol_for_pairing( + client=client, host_static_randomness=randomness_static, credential=credential + ) + protocol._send_message(ThpEndRequest()) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "thp_connection_request" + protocol._send_message(ButtonAck()) + client.debug.press_yes() + protocol._read_message(ThpEndResponse) + + +def test_autoconnect_credential_request_cancel(client: Client) -> None: + protocol = prepare_protocol_for_pairing(client) + _nfc_pairing(client, protocol) + + # Request credential with confirmation after pairing + randomness_static = os.urandom(32) + host_static_privkey = curve25519.get_private_key(randomness_static) + host_static_pubkey = curve25519.get_public_key(host_static_privkey) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False) + ) + credential_response = protocol._read_message(ThpCredentialResponse) + + assert credential_response.credential is not None + credential = credential_response.credential + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + # Connect using credential with confirmation and request autoconnect + protocol = prepare_protocol_for_pairing( + client=client, host_static_randomness=randomness_static, credential=credential + ) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=True) + ) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "thp_connection_request" + protocol._send_message(ButtonAck()) + client.debug.press_yes() + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "thp_autoconnect_credential_request" + protocol._send_message(Cancel()) + failure = protocol._read_message(Failure) + + assert failure.code == FailureType.ActionCancelled + + def test_credential_phase(client: Client) -> None: protocol = prepare_protocol_for_pairing(client) _nfc_pairing(client, protocol)