From 1e5decb2e488adbfd914d9accc0342df4c3ad0a3 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 31 Jan 2025 16:11:48 +0100 Subject: [PATCH] refactor: improve structure of code for showing pairing screens and storing pairing data [no changelog] --- core/src/apps/debug/__init__.py | 4 +- core/src/apps/thp/pairing.py | 40 ++--- core/src/trezor/wire/thp/pairing_context.py | 159 +++++++++----------- 3 files changed, 80 insertions(+), 123 deletions(-) diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 0f34c5ac34..93a0342724 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -299,8 +299,8 @@ if __debug__: return DebugLinkPairingInfo( channel_id=ctx.channel_id, handshake_hash=ctx.channel_ctx.get_handshake_hash(), - code_entry_code=ctx.display_data.code_code_entry, - code_qr_code=ctx.display_data.code_qr_code, + code_entry_code=ctx.code_code_entry, + code_qr_code=ctx.code_qr_code, nfc_secret_trezor=ctx.nfc_secret, ) diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index a195f5420f..0031c93420 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from ubinascii import hexlify -from trezor import loop, protobuf +from trezor import protobuf from trezor.crypto import random from trezor.crypto.hashlib import sha256 from trezor.enums import ThpMessageType, ThpPairingMethod @@ -27,7 +27,7 @@ from trezor.messages import ( ) from trezor.wire import message_handler from trezor.wire.context import UnexpectedMessageException -from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage +from trezor.wire.errors import SilentError, UnexpectedMessage from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods from trezor.wire.thp.pairing_context import PairingContext @@ -37,7 +37,7 @@ if __debug__: from trezor import log if TYPE_CHECKING: - from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple + from typing import Any, Callable, Concatenate, ParamSpec, Tuple P = ParamSpec("P") FuncWithContext = Callable[Concatenate[PairingContext, P], Any] @@ -186,22 +186,6 @@ async def _prepare_pairing(ctx: PairingContext) -> None: raise Exception() # TODO unknown pairing method -async def show_display_data( - ctx: PairingContext, expected_types: Container[int] = () -) -> type[protobuf.MessageType]: - from trezorui_api import CANCELLED - - read_task = ctx.read(expected_types) - cancel_task = ctx.display_data.get_display_layout() - race = loop.race(read_task, cancel_task.get_result()) - result: type[protobuf.MessageType] = await race - - if result is CANCELLED: - raise ActionCancelled - - return result - - @check_state_and_log(ChannelState.TP1) async def _handle_code_entry_is_selected(ctx: PairingContext) -> None: if ctx.code_entry_secret is None: @@ -231,16 +215,12 @@ async def _handle_code_entry_is_selected_first_time(ctx: PairingContext) -> None sha_ctx.update(ctx.code_entry_secret) sha_ctx.update(challenge_message.challenge) code_code_entry_hash = sha_ctx.digest() - ctx.display_data.code_code_entry = ( - int.from_bytes(code_code_entry_hash, "big") % 1000000 - ) + ctx.code_code_entry = int.from_bytes(code_code_entry_hash, "big") % 1000000 ctx.cpace = Cpace( ctx.channel_ctx.get_handshake_hash(), ) - assert ctx.display_data.code_code_entry is not None - ctx.cpace.generate_keys_and_secret( - ctx.display_data.code_code_entry.to_bytes(6, "big") - ) + assert ctx.code_code_entry is not None + ctx.cpace.generate_keys_and_secret(ctx.code_code_entry.to_bytes(6, "big")) await ctx.write_force( ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key) ) @@ -260,7 +240,7 @@ async def _handle_qr_code_is_selected(ctx: PairingContext) -> None: sha_ctx.update(ctx.channel_ctx.get_handshake_hash()) sha_ctx.update(ctx.qr_code_secret) - ctx.display_data.code_qr_code = sha_ctx.digest()[:16] + ctx.code_qr_code = sha_ctx.digest()[:16] await ctx.write_force(ThpPairingPreparationsFinished()) @@ -317,9 +297,9 @@ async def _handle_qr_code_tag( ) -> protobuf.MessageType: if TYPE_CHECKING: assert isinstance(message, ThpQrCodeTag) - assert ctx.display_data.code_qr_code is not None + assert ctx.code_qr_code is not None sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash()) - sha_ctx.update(ctx.display_data.code_qr_code) + sha_ctx.update(ctx.code_qr_code) expected_tag = sha_ctx.digest() if expected_tag != message.tag: print( @@ -331,7 +311,7 @@ async def _handle_qr_code_tag( ) # TODO remove after testing print( "expected code qr code:", - hexlify(ctx.display_data.code_qr_code).decode(), + hexlify(ctx.code_qr_code).decode(), ) # TODO remove after testing print( "expected secret:", hexlify(ctx.qr_code_secret or b"").decode() diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 9d7c97438d..4edfea2a90 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -13,7 +13,6 @@ from trezor.wire.thp import ChannelState, get_enabled_pairing_methods if TYPE_CHECKING: from typing import Awaitable, Container - from trezor import ui from trezor.enums import ThpPairingMethod from trezorui_api import UiResult @@ -26,54 +25,6 @@ if __debug__: from trezor import log -class PairingDisplayData: - - def __init__(self) -> None: - self.code_code_entry: int | None = None - self.code_qr_code: bytes | None = None - self.code_nfc: bytes | None = None - - def get_display_layout(self) -> ui.Layout: - from trezor import ui - - # TODO have different layouts when there is only QR code or only Code Entry - qr_str = "" - code_str = "" - if self.code_qr_code is not None: - qr_str = self.get_code_qr_code_str() - if self.code_code_entry is not None: - code_str = self.get_code_code_entry_str() - - return ui.Layout( - trezorui_api.show_address_details( # noqa - qr_title="Scan QR code to pair", - address=qr_str, - case_sensitive=True, - details_title="", - account="Code to rewrite:\n" + code_str, - path="", - xpubs=[], - ) - ) - - def get_code_code_entry_str(self) -> str: - if self.code_code_entry is not None: - code_str = f"{self.code_code_entry:06}" - if __debug__: - log.debug(__name__, "code_code_entry: %s", code_str) - - return code_str[:3] + " " + code_str[3:] - raise Exception("Code entry string is not available") - - def get_code_qr_code_str(self) -> str: - if self.code_qr_code is not None: - code_str = (hexlify(self.code_qr_code)).decode("utf-8") - if __debug__: - log.debug(__name__, "code_qr_code_hexlified: %s", code_str) - return code_str - raise Exception("QR code string is not available") - - class PairingContext(Context): def __init__(self, channel_ctx: Channel) -> None: @@ -86,11 +37,13 @@ class PairingContext(Context): self.selected_method: ThpPairingMethod + self.code_code_entry: int | None = None + self.code_qr_code: bytes | None = None + self.code_nfc: bytes | None = None # The 2 following attributes are important for NFC pairing self.nfc_secret_host: bytes | None = None self.handshake_hash_host: bytes | None = None - self.display_data: PairingDisplayData = PairingDisplayData() self.cpace: Cpace self.host_name: str | None @@ -117,25 +70,19 @@ class PairingContext(Context): try: next_message = await handle_message(self, message) except Exception as exc: - # Log and ignore. The session handler can only exit explicitly in the + # Log and ignore. The context handler can only exit explicitly in the # following finally block. if __debug__: log.exception(__name__, exc) finally: - # Unload modules imported by the workflow. Should not raise. - # This is not done for the debug session because the snapshot taken - # in a debug session would clear modules which are in use by the - # workflow running on wire. - # TODO utils.unimport_end(modules) - if next_message is None: # Shut down the loop if there is no next message waiting. return # pylint: disable=lost-exception except Exception as exc: - # Log and try again. The session handler can only exit explicitly via - # loop.clear() above. # TODO not updated comments + # Log and try again. The context handler can only exit explicitly via + # finally block above if __debug__: log.exception(__name__, exc) @@ -236,45 +183,75 @@ class PairingContext(Context): self, selected_method: ThpPairingMethod | None = None ) -> UiResult: from trezor.enums import ThpPairingMethod - from trezor.ui.layouts.common import interact if selected_method is None: selected_method = self.selected_method if selected_method is ThpPairingMethod.CodeEntry: - result = await interact( - trezorui_api.show_simple( - title="Copy the following", - text=self.display_data.get_code_code_entry_str(), - ), - br_name="pairing_code_entry", - br_code=ButtonRequestType.Other, - ) - elif selected_method is ThpPairingMethod.QrCode: - result = await interact( - trezorui_api.show_address_details( # noqa - qr_title="Scan QR code to pair", - address=self.display_data.get_code_qr_code_str(), - case_sensitive=True, - details_title="", - account="", - path="", - xpubs=[], - ), - br_name="pairing_qr_code", - br_code=ButtonRequestType.Other, - ) + return await self._show_code_entry_screen() elif selected_method is ThpPairingMethod.NFC: - result = await interact( - trezorui_api.show_simple( - title="NFC Pairing", - text="Move your device close to Trezor", - ), - br_name="pairing_nfc", - br_code=ButtonRequestType.Other, - ) + return await self._show_nfc_screen() + elif selected_method is ThpPairingMethod.QrCode: + return await self._show_qr_code_screen() else: raise Exception("Unknown pairing method") - return result + + async def _show_code_entry_screen(self) -> UiResult: + from trezor.ui.layouts.common import interact + + return await interact( + trezorui_api.show_simple( + title="Copy the following", + text=self._get_code_code_entry_str(), + ), + br_name="pairing_code_entry", + br_code=ButtonRequestType.Other, + ) + + async def _show_nfc_screen(self) -> UiResult: + from trezor.ui.layouts.common import interact + + return await interact( + trezorui_api.show_simple( + title="NFC Pairing", + text="Move your device close to Trezor", + ), + br_name="pairing_nfc", + br_code=ButtonRequestType.Other, + ) + + async def _show_qr_code_screen(self) -> UiResult: + from trezor.ui.layouts.common import interact + + return await interact( + trezorui_api.show_address_details( # noqa + qr_title="Scan QR code to pair", + address=self._get_code_qr_code_str(), + case_sensitive=True, + details_title="", + account="", + path="", + xpubs=[], + ), + br_name="pairing_qr_code", + br_code=ButtonRequestType.Other, + ) + + def _get_code_code_entry_str(self) -> str: + if self.code_code_entry is not None: + code_str = f"{self.code_code_entry:06}" + if __debug__: + log.debug(__name__, "code_code_entry: %s", code_str) + + return code_str[:3] + " " + code_str[3:] + raise Exception("Code entry string is not available") + + def _get_code_qr_code_str(self) -> str: + if self.code_qr_code is not None: + code_str = (hexlify(self.code_qr_code)).decode("utf-8") + if __debug__: + log.debug(__name__, "code_qr_code_hexlified: %s", code_str) + return code_str + raise Exception("QR code string is not available") async def handle_message(