1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-01 18:30:56 +00:00

refactor: improve structure of code for showing pairing screens and storing pairing data

[no changelog]
This commit is contained in:
M1nd3r 2025-01-31 16:11:48 +01:00
parent 91c9f8bcf0
commit 1e5decb2e4
3 changed files with 80 additions and 123 deletions

View File

@ -299,8 +299,8 @@ if __debug__:
return DebugLinkPairingInfo( return DebugLinkPairingInfo(
channel_id=ctx.channel_id, channel_id=ctx.channel_id,
handshake_hash=ctx.channel_ctx.get_handshake_hash(), handshake_hash=ctx.channel_ctx.get_handshake_hash(),
code_entry_code=ctx.display_data.code_code_entry, code_entry_code=ctx.code_code_entry,
code_qr_code=ctx.display_data.code_qr_code, code_qr_code=ctx.code_qr_code,
nfc_secret_trezor=ctx.nfc_secret, nfc_secret_trezor=ctx.nfc_secret,
) )

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ubinascii import hexlify from ubinascii import hexlify
from trezor import loop, protobuf from trezor import protobuf
from trezor.crypto import random from trezor.crypto import random
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.enums import ThpMessageType, ThpPairingMethod from trezor.enums import ThpMessageType, ThpPairingMethod
@ -27,7 +27,7 @@ from trezor.messages import (
) )
from trezor.wire import message_handler from trezor.wire import message_handler
from trezor.wire.context import UnexpectedMessageException from trezor.wire.context import UnexpectedMessageException
from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage from trezor.wire.errors import SilentError, UnexpectedMessage
from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods
from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.pairing_context import PairingContext
@ -37,7 +37,7 @@ if __debug__:
from trezor import log from trezor import log
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple from typing import Any, Callable, Concatenate, ParamSpec, Tuple
P = ParamSpec("P") P = ParamSpec("P")
FuncWithContext = Callable[Concatenate[PairingContext, P], Any] FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
@ -186,22 +186,6 @@ async def _prepare_pairing(ctx: PairingContext) -> None:
raise Exception() # TODO unknown pairing method raise Exception() # TODO unknown pairing method
async def show_display_data(
ctx: PairingContext, expected_types: Container[int] = ()
) -> type[protobuf.MessageType]:
from trezorui_api import CANCELLED
read_task = ctx.read(expected_types)
cancel_task = ctx.display_data.get_display_layout()
race = loop.race(read_task, cancel_task.get_result())
result: type[protobuf.MessageType] = await race
if result is CANCELLED:
raise ActionCancelled
return result
@check_state_and_log(ChannelState.TP1) @check_state_and_log(ChannelState.TP1)
async def _handle_code_entry_is_selected(ctx: PairingContext) -> None: async def _handle_code_entry_is_selected(ctx: PairingContext) -> None:
if ctx.code_entry_secret is None: if ctx.code_entry_secret is None:
@ -231,16 +215,12 @@ async def _handle_code_entry_is_selected_first_time(ctx: PairingContext) -> None
sha_ctx.update(ctx.code_entry_secret) sha_ctx.update(ctx.code_entry_secret)
sha_ctx.update(challenge_message.challenge) sha_ctx.update(challenge_message.challenge)
code_code_entry_hash = sha_ctx.digest() code_code_entry_hash = sha_ctx.digest()
ctx.display_data.code_code_entry = ( ctx.code_code_entry = int.from_bytes(code_code_entry_hash, "big") % 1000000
int.from_bytes(code_code_entry_hash, "big") % 1000000
)
ctx.cpace = Cpace( ctx.cpace = Cpace(
ctx.channel_ctx.get_handshake_hash(), ctx.channel_ctx.get_handshake_hash(),
) )
assert ctx.display_data.code_code_entry is not None assert ctx.code_code_entry is not None
ctx.cpace.generate_keys_and_secret( ctx.cpace.generate_keys_and_secret(ctx.code_code_entry.to_bytes(6, "big"))
ctx.display_data.code_code_entry.to_bytes(6, "big")
)
await ctx.write_force( await ctx.write_force(
ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key) ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key)
) )
@ -260,7 +240,7 @@ async def _handle_qr_code_is_selected(ctx: PairingContext) -> None:
sha_ctx.update(ctx.channel_ctx.get_handshake_hash()) sha_ctx.update(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.qr_code_secret) sha_ctx.update(ctx.qr_code_secret)
ctx.display_data.code_qr_code = sha_ctx.digest()[:16] ctx.code_qr_code = sha_ctx.digest()[:16]
await ctx.write_force(ThpPairingPreparationsFinished()) await ctx.write_force(ThpPairingPreparationsFinished())
@ -317,9 +297,9 @@ async def _handle_qr_code_tag(
) -> protobuf.MessageType: ) -> protobuf.MessageType:
if TYPE_CHECKING: if TYPE_CHECKING:
assert isinstance(message, ThpQrCodeTag) assert isinstance(message, ThpQrCodeTag)
assert ctx.display_data.code_qr_code is not None assert ctx.code_qr_code is not None
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash()) sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.display_data.code_qr_code) sha_ctx.update(ctx.code_qr_code)
expected_tag = sha_ctx.digest() expected_tag = sha_ctx.digest()
if expected_tag != message.tag: if expected_tag != message.tag:
print( print(
@ -331,7 +311,7 @@ async def _handle_qr_code_tag(
) # TODO remove after testing ) # TODO remove after testing
print( print(
"expected code qr code:", "expected code qr code:",
hexlify(ctx.display_data.code_qr_code).decode(), hexlify(ctx.code_qr_code).decode(),
) # TODO remove after testing ) # TODO remove after testing
print( print(
"expected secret:", hexlify(ctx.qr_code_secret or b"").decode() "expected secret:", hexlify(ctx.qr_code_secret or b"").decode()

View File

@ -13,7 +13,6 @@ from trezor.wire.thp import ChannelState, get_enabled_pairing_methods
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Awaitable, Container from typing import Awaitable, Container
from trezor import ui
from trezor.enums import ThpPairingMethod from trezor.enums import ThpPairingMethod
from trezorui_api import UiResult from trezorui_api import UiResult
@ -26,54 +25,6 @@ if __debug__:
from trezor import log from trezor import log
class PairingDisplayData:
def __init__(self) -> None:
self.code_code_entry: int | None = None
self.code_qr_code: bytes | None = None
self.code_nfc: bytes | None = None
def get_display_layout(self) -> ui.Layout:
from trezor import ui
# TODO have different layouts when there is only QR code or only Code Entry
qr_str = ""
code_str = ""
if self.code_qr_code is not None:
qr_str = self.get_code_qr_code_str()
if self.code_code_entry is not None:
code_str = self.get_code_code_entry_str()
return ui.Layout(
trezorui_api.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=qr_str,
case_sensitive=True,
details_title="",
account="Code to rewrite:\n" + code_str,
path="",
xpubs=[],
)
)
def get_code_code_entry_str(self) -> str:
if self.code_code_entry is not None:
code_str = f"{self.code_code_entry:06}"
if __debug__:
log.debug(__name__, "code_code_entry: %s", code_str)
return code_str[:3] + " " + code_str[3:]
raise Exception("Code entry string is not available")
def get_code_qr_code_str(self) -> str:
if self.code_qr_code is not None:
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
if __debug__:
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
return code_str
raise Exception("QR code string is not available")
class PairingContext(Context): class PairingContext(Context):
def __init__(self, channel_ctx: Channel) -> None: def __init__(self, channel_ctx: Channel) -> None:
@ -86,11 +37,13 @@ class PairingContext(Context):
self.selected_method: ThpPairingMethod self.selected_method: ThpPairingMethod
self.code_code_entry: int | None = None
self.code_qr_code: bytes | None = None
self.code_nfc: bytes | None = None
# The 2 following attributes are important for NFC pairing # The 2 following attributes are important for NFC pairing
self.nfc_secret_host: bytes | None = None self.nfc_secret_host: bytes | None = None
self.handshake_hash_host: bytes | None = None self.handshake_hash_host: bytes | None = None
self.display_data: PairingDisplayData = PairingDisplayData()
self.cpace: Cpace self.cpace: Cpace
self.host_name: str | None self.host_name: str | None
@ -117,25 +70,19 @@ class PairingContext(Context):
try: try:
next_message = await handle_message(self, message) next_message = await handle_message(self, message)
except Exception as exc: except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the # Log and ignore. The context handler can only exit explicitly in the
# following finally block. # following finally block.
if __debug__: if __debug__:
log.exception(__name__, exc) log.exception(__name__, exc)
finally: finally:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None: if next_message is None:
# Shut down the loop if there is no next message waiting. # Shut down the loop if there is no next message waiting.
return # pylint: disable=lost-exception return # pylint: disable=lost-exception
except Exception as exc: except Exception as exc:
# Log and try again. The session handler can only exit explicitly via # Log and try again. The context handler can only exit explicitly via
# loop.clear() above. # TODO not updated comments # finally block above
if __debug__: if __debug__:
log.exception(__name__, exc) log.exception(__name__, exc)
@ -236,45 +183,75 @@ class PairingContext(Context):
self, selected_method: ThpPairingMethod | None = None self, selected_method: ThpPairingMethod | None = None
) -> UiResult: ) -> UiResult:
from trezor.enums import ThpPairingMethod from trezor.enums import ThpPairingMethod
from trezor.ui.layouts.common import interact
if selected_method is None: if selected_method is None:
selected_method = self.selected_method selected_method = self.selected_method
if selected_method is ThpPairingMethod.CodeEntry: if selected_method is ThpPairingMethod.CodeEntry:
result = await interact( return await self._show_code_entry_screen()
trezorui_api.show_simple(
title="Copy the following",
text=self.display_data.get_code_code_entry_str(),
),
br_name="pairing_code_entry",
br_code=ButtonRequestType.Other,
)
elif selected_method is ThpPairingMethod.QrCode:
result = await interact(
trezorui_api.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=self.display_data.get_code_qr_code_str(),
case_sensitive=True,
details_title="",
account="",
path="",
xpubs=[],
),
br_name="pairing_qr_code",
br_code=ButtonRequestType.Other,
)
elif selected_method is ThpPairingMethod.NFC: elif selected_method is ThpPairingMethod.NFC:
result = await interact( return await self._show_nfc_screen()
trezorui_api.show_simple( elif selected_method is ThpPairingMethod.QrCode:
title="NFC Pairing", return await self._show_qr_code_screen()
text="Move your device close to Trezor",
),
br_name="pairing_nfc",
br_code=ButtonRequestType.Other,
)
else: else:
raise Exception("Unknown pairing method") raise Exception("Unknown pairing method")
return result
async def _show_code_entry_screen(self) -> UiResult:
from trezor.ui.layouts.common import interact
return await interact(
trezorui_api.show_simple(
title="Copy the following",
text=self._get_code_code_entry_str(),
),
br_name="pairing_code_entry",
br_code=ButtonRequestType.Other,
)
async def _show_nfc_screen(self) -> UiResult:
from trezor.ui.layouts.common import interact
return await interact(
trezorui_api.show_simple(
title="NFC Pairing",
text="Move your device close to Trezor",
),
br_name="pairing_nfc",
br_code=ButtonRequestType.Other,
)
async def _show_qr_code_screen(self) -> UiResult:
from trezor.ui.layouts.common import interact
return await interact(
trezorui_api.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=self._get_code_qr_code_str(),
case_sensitive=True,
details_title="",
account="",
path="",
xpubs=[],
),
br_name="pairing_qr_code",
br_code=ButtonRequestType.Other,
)
def _get_code_code_entry_str(self) -> str:
if self.code_code_entry is not None:
code_str = f"{self.code_code_entry:06}"
if __debug__:
log.debug(__name__, "code_code_entry: %s", code_str)
return code_str[:3] + " " + code_str[3:]
raise Exception("Code entry string is not available")
def _get_code_qr_code_str(self) -> str:
if self.code_qr_code is not None:
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
if __debug__:
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
return code_str
raise Exception("QR code string is not available")
async def handle_message( async def handle_message(