1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-01 18:30:56 +00:00

refactor: improve structure of code for showing pairing screens and storing pairing data

[no changelog]
This commit is contained in:
M1nd3r 2025-01-31 16:11:48 +01:00
parent 91c9f8bcf0
commit 1e5decb2e4
3 changed files with 80 additions and 123 deletions

View File

@ -299,8 +299,8 @@ if __debug__:
return DebugLinkPairingInfo(
channel_id=ctx.channel_id,
handshake_hash=ctx.channel_ctx.get_handshake_hash(),
code_entry_code=ctx.display_data.code_code_entry,
code_qr_code=ctx.display_data.code_qr_code,
code_entry_code=ctx.code_code_entry,
code_qr_code=ctx.code_qr_code,
nfc_secret_trezor=ctx.nfc_secret,
)

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import loop, protobuf
from trezor import protobuf
from trezor.crypto import random
from trezor.crypto.hashlib import sha256
from trezor.enums import ThpMessageType, ThpPairingMethod
@ -27,7 +27,7 @@ from trezor.messages import (
)
from trezor.wire import message_handler
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage
from trezor.wire.errors import SilentError, UnexpectedMessage
from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods
from trezor.wire.thp.pairing_context import PairingContext
@ -37,7 +37,7 @@ if __debug__:
from trezor import log
if TYPE_CHECKING:
from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple
from typing import Any, Callable, Concatenate, ParamSpec, Tuple
P = ParamSpec("P")
FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
@ -186,22 +186,6 @@ async def _prepare_pairing(ctx: PairingContext) -> None:
raise Exception() # TODO unknown pairing method
async def show_display_data(
ctx: PairingContext, expected_types: Container[int] = ()
) -> type[protobuf.MessageType]:
from trezorui_api import CANCELLED
read_task = ctx.read(expected_types)
cancel_task = ctx.display_data.get_display_layout()
race = loop.race(read_task, cancel_task.get_result())
result: type[protobuf.MessageType] = await race
if result is CANCELLED:
raise ActionCancelled
return result
@check_state_and_log(ChannelState.TP1)
async def _handle_code_entry_is_selected(ctx: PairingContext) -> None:
if ctx.code_entry_secret is None:
@ -231,16 +215,12 @@ async def _handle_code_entry_is_selected_first_time(ctx: PairingContext) -> None
sha_ctx.update(ctx.code_entry_secret)
sha_ctx.update(challenge_message.challenge)
code_code_entry_hash = sha_ctx.digest()
ctx.display_data.code_code_entry = (
int.from_bytes(code_code_entry_hash, "big") % 1000000
)
ctx.code_code_entry = int.from_bytes(code_code_entry_hash, "big") % 1000000
ctx.cpace = Cpace(
ctx.channel_ctx.get_handshake_hash(),
)
assert ctx.display_data.code_code_entry is not None
ctx.cpace.generate_keys_and_secret(
ctx.display_data.code_code_entry.to_bytes(6, "big")
)
assert ctx.code_code_entry is not None
ctx.cpace.generate_keys_and_secret(ctx.code_code_entry.to_bytes(6, "big"))
await ctx.write_force(
ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key)
)
@ -260,7 +240,7 @@ async def _handle_qr_code_is_selected(ctx: PairingContext) -> None:
sha_ctx.update(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.qr_code_secret)
ctx.display_data.code_qr_code = sha_ctx.digest()[:16]
ctx.code_qr_code = sha_ctx.digest()[:16]
await ctx.write_force(ThpPairingPreparationsFinished())
@ -317,9 +297,9 @@ async def _handle_qr_code_tag(
) -> protobuf.MessageType:
if TYPE_CHECKING:
assert isinstance(message, ThpQrCodeTag)
assert ctx.display_data.code_qr_code is not None
assert ctx.code_qr_code is not None
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
sha_ctx.update(ctx.display_data.code_qr_code)
sha_ctx.update(ctx.code_qr_code)
expected_tag = sha_ctx.digest()
if expected_tag != message.tag:
print(
@ -331,7 +311,7 @@ async def _handle_qr_code_tag(
) # TODO remove after testing
print(
"expected code qr code:",
hexlify(ctx.display_data.code_qr_code).decode(),
hexlify(ctx.code_qr_code).decode(),
) # TODO remove after testing
print(
"expected secret:", hexlify(ctx.qr_code_secret or b"").decode()

View File

@ -13,7 +13,6 @@ from trezor.wire.thp import ChannelState, get_enabled_pairing_methods
if TYPE_CHECKING:
from typing import Awaitable, Container
from trezor import ui
from trezor.enums import ThpPairingMethod
from trezorui_api import UiResult
@ -26,54 +25,6 @@ if __debug__:
from trezor import log
class PairingDisplayData:
def __init__(self) -> None:
self.code_code_entry: int | None = None
self.code_qr_code: bytes | None = None
self.code_nfc: bytes | None = None
def get_display_layout(self) -> ui.Layout:
from trezor import ui
# TODO have different layouts when there is only QR code or only Code Entry
qr_str = ""
code_str = ""
if self.code_qr_code is not None:
qr_str = self.get_code_qr_code_str()
if self.code_code_entry is not None:
code_str = self.get_code_code_entry_str()
return ui.Layout(
trezorui_api.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=qr_str,
case_sensitive=True,
details_title="",
account="Code to rewrite:\n" + code_str,
path="",
xpubs=[],
)
)
def get_code_code_entry_str(self) -> str:
if self.code_code_entry is not None:
code_str = f"{self.code_code_entry:06}"
if __debug__:
log.debug(__name__, "code_code_entry: %s", code_str)
return code_str[:3] + " " + code_str[3:]
raise Exception("Code entry string is not available")
def get_code_qr_code_str(self) -> str:
if self.code_qr_code is not None:
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
if __debug__:
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
return code_str
raise Exception("QR code string is not available")
class PairingContext(Context):
def __init__(self, channel_ctx: Channel) -> None:
@ -86,11 +37,13 @@ class PairingContext(Context):
self.selected_method: ThpPairingMethod
self.code_code_entry: int | None = None
self.code_qr_code: bytes | None = None
self.code_nfc: bytes | None = None
# The 2 following attributes are important for NFC pairing
self.nfc_secret_host: bytes | None = None
self.handshake_hash_host: bytes | None = None
self.display_data: PairingDisplayData = PairingDisplayData()
self.cpace: Cpace
self.host_name: str | None
@ -117,25 +70,19 @@ class PairingContext(Context):
try:
next_message = await handle_message(self, message)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# Log and ignore. The context handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None:
# Shut down the loop if there is no next message waiting.
return # pylint: disable=lost-exception
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above. # TODO not updated comments
# Log and try again. The context handler can only exit explicitly via
# finally block above
if __debug__:
log.exception(__name__, exc)
@ -236,45 +183,75 @@ class PairingContext(Context):
self, selected_method: ThpPairingMethod | None = None
) -> UiResult:
from trezor.enums import ThpPairingMethod
from trezor.ui.layouts.common import interact
if selected_method is None:
selected_method = self.selected_method
if selected_method is ThpPairingMethod.CodeEntry:
result = await interact(
trezorui_api.show_simple(
title="Copy the following",
text=self.display_data.get_code_code_entry_str(),
),
br_name="pairing_code_entry",
br_code=ButtonRequestType.Other,
)
elif selected_method is ThpPairingMethod.QrCode:
result = await interact(
trezorui_api.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=self.display_data.get_code_qr_code_str(),
case_sensitive=True,
details_title="",
account="",
path="",
xpubs=[],
),
br_name="pairing_qr_code",
br_code=ButtonRequestType.Other,
)
return await self._show_code_entry_screen()
elif selected_method is ThpPairingMethod.NFC:
result = await interact(
trezorui_api.show_simple(
title="NFC Pairing",
text="Move your device close to Trezor",
),
br_name="pairing_nfc",
br_code=ButtonRequestType.Other,
)
return await self._show_nfc_screen()
elif selected_method is ThpPairingMethod.QrCode:
return await self._show_qr_code_screen()
else:
raise Exception("Unknown pairing method")
return result
async def _show_code_entry_screen(self) -> UiResult:
from trezor.ui.layouts.common import interact
return await interact(
trezorui_api.show_simple(
title="Copy the following",
text=self._get_code_code_entry_str(),
),
br_name="pairing_code_entry",
br_code=ButtonRequestType.Other,
)
async def _show_nfc_screen(self) -> UiResult:
from trezor.ui.layouts.common import interact
return await interact(
trezorui_api.show_simple(
title="NFC Pairing",
text="Move your device close to Trezor",
),
br_name="pairing_nfc",
br_code=ButtonRequestType.Other,
)
async def _show_qr_code_screen(self) -> UiResult:
from trezor.ui.layouts.common import interact
return await interact(
trezorui_api.show_address_details( # noqa
qr_title="Scan QR code to pair",
address=self._get_code_qr_code_str(),
case_sensitive=True,
details_title="",
account="",
path="",
xpubs=[],
),
br_name="pairing_qr_code",
br_code=ButtonRequestType.Other,
)
def _get_code_code_entry_str(self) -> str:
if self.code_code_entry is not None:
code_str = f"{self.code_code_entry:06}"
if __debug__:
log.debug(__name__, "code_code_entry: %s", code_str)
return code_str[:3] + " " + code_str[3:]
raise Exception("Code entry string is not available")
def _get_code_qr_code_str(self) -> str:
if self.code_qr_code is not None:
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
if __debug__:
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
return code_str
raise Exception("QR code string is not available")
async def handle_message(