1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-24 00:48:45 +00:00

chore: handle cancel in pairing and credential flows

This commit is contained in:
M1nd3r 2025-04-15 16:06:41 +02:00
parent 4ec92c6b49
commit beb6c00b19
3 changed files with 96 additions and 7 deletions

View File

@ -3,7 +3,7 @@ from typing import TYPE_CHECKING
from ubinascii import hexlify from ubinascii import hexlify
import trezorui_api import trezorui_api
from trezor import loop, protobuf, workflow from trezor import loop, protobuf, utils, workflow
from trezor.enums import ButtonRequestType from trezor.enums import ButtonRequestType
from trezor.wire import context, message_handler, protocol_common from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException from trezor.wire.context import UnexpectedMessageException
@ -104,8 +104,12 @@ class PairingContext(Context):
) )
message: Message = await self.incoming_message message: Message = await self.incoming_message
if message.type not in expected_types: if message.type not in expected_types:
from trezor.messages import Cancel
if message.type == Cancel.MESSAGE_WIRE_TYPE:
raise ActionCancelled
raise UnexpectedMessageException(message) raise UnexpectedMessageException(message)
if expected_type is None: if expected_type is None:
@ -192,18 +196,20 @@ class PairingContext(Context):
from trezor.ui.layouts.common import interact from trezor.ui.layouts.common import interact
if not device_name: if not device_name:
action_string = f"Allow {self.host_name} to pair with this Trezor?" action_string = f"Allow {self.host_name} to connect with this Trezor?"
else: else:
action_string = ( action_string = (
f"Allow {self.host_name} on {device_name} to pair with this Trezor?" f"Allow {self.host_name} on {device_name} to connect with this Trezor?"
) )
await interact( result = await interact(
trezorui_api.confirm_action( trezorui_api.confirm_action(
title="Connection dialog", action=action_string, description=None title="Connection dialog", action=action_string, description=None
), ),
br_name="thp_connection_request", br_name="thp_connection_request",
br_code=ButtonRequestType.Other, br_code=ButtonRequestType.Other,
) )
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Result of connection dialog %s", str(result))
async def show_autoconnect_credential_confirmation_screen(self) -> None: async def show_autoconnect_credential_confirmation_screen(self) -> None:
from trezor.ui.layouts.common import interact from trezor.ui.layouts.common import interact

View File

@ -17,10 +17,14 @@ def prepare_protocol_for_handshake(client: Client) -> ProtocolV2Channel:
def prepare_protocol_for_pairing( def prepare_protocol_for_pairing(
client: Client, host_static_randomness: bytes | None = None client: Client,
host_static_randomness: bytes | None = None,
credential: bytes | None = None,
) -> ProtocolV2Channel: ) -> ProtocolV2Channel:
protocol = prepare_protocol_for_handshake(client) protocol = prepare_protocol_for_handshake(client)
protocol._do_handshake(host_static_randomness=host_static_randomness) protocol._do_handshake(
credential=credential, host_static_randomness=host_static_randomness
)
return protocol return protocol

View File

@ -229,6 +229,85 @@ def _nfc_pairing(client: Client, protocol: ProtocolV2Channel) -> None:
assert tag_trezor_msg.tag == computed_tag assert tag_trezor_msg.tag == computed_tag
def test_connection_confirmation_cancel(client: Client) -> None:
protocol = prepare_protocol_for_pairing(client)
_nfc_pairing(client, protocol)
# Request credential with confirmation after pairing
randomness_static = os.urandom(32)
host_static_privkey = curve25519.get_private_key(randomness_static)
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
protocol._send_message(
ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False)
)
credential_response = protocol._read_message(ThpCredentialResponse)
assert credential_response.credential is not None
credential = credential_response.credential
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
# Connect using credential with confirmation
protocol = prepare_protocol_for_pairing(
client=client, host_static_randomness=randomness_static, credential=credential
)
protocol._send_message(ThpEndRequest())
button_req = protocol._read_message(ButtonRequest)
assert button_req.name == "thp_connection_request"
protocol._send_message(Cancel())
failure = protocol._read_message(Failure)
assert failure.code == FailureType.ActionCancelled
time.sleep(0.2) # TODO fix this behavior
protocol = prepare_protocol_for_pairing(
client=client, host_static_randomness=randomness_static, credential=credential
)
protocol._send_message(ThpEndRequest())
button_req = protocol._read_message(ButtonRequest)
assert button_req.name == "thp_connection_request"
protocol._send_message(ButtonAck())
client.debug.press_yes()
protocol._read_message(ThpEndResponse)
def test_autoconnect_credential_request_cancel(client: Client) -> None:
protocol = prepare_protocol_for_pairing(client)
_nfc_pairing(client, protocol)
# Request credential with confirmation after pairing
randomness_static = os.urandom(32)
host_static_privkey = curve25519.get_private_key(randomness_static)
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
protocol._send_message(
ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False)
)
credential_response = protocol._read_message(ThpCredentialResponse)
assert credential_response.credential is not None
credential = credential_response.credential
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
# Connect using credential with confirmation and request autoconnect
protocol = prepare_protocol_for_pairing(
client=client, host_static_randomness=randomness_static, credential=credential
)
protocol._send_message(
ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=True)
)
button_req = protocol._read_message(ButtonRequest)
assert button_req.name == "thp_connection_request"
protocol._send_message(ButtonAck())
client.debug.press_yes()
button_req = protocol._read_message(ButtonRequest)
assert button_req.name == "thp_autoconnect_credential_request"
protocol._send_message(Cancel())
failure = protocol._read_message(Failure)
assert failure.code == FailureType.ActionCancelled
def test_credential_phase(client: Client) -> None: def test_credential_phase(client: Client) -> None:
protocol = prepare_protocol_for_pairing(client) protocol = prepare_protocol_for_pairing(client)
_nfc_pairing(client, protocol) _nfc_pairing(client, protocol)