mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-01 18:30:56 +00:00
refactor: improve structure of code for showing pairing screens and storing pairing data
[no changelog]
This commit is contained in:
parent
91c9f8bcf0
commit
1e5decb2e4
@ -299,8 +299,8 @@ if __debug__:
|
|||||||
return DebugLinkPairingInfo(
|
return DebugLinkPairingInfo(
|
||||||
channel_id=ctx.channel_id,
|
channel_id=ctx.channel_id,
|
||||||
handshake_hash=ctx.channel_ctx.get_handshake_hash(),
|
handshake_hash=ctx.channel_ctx.get_handshake_hash(),
|
||||||
code_entry_code=ctx.display_data.code_code_entry,
|
code_entry_code=ctx.code_code_entry,
|
||||||
code_qr_code=ctx.display_data.code_qr_code,
|
code_qr_code=ctx.code_qr_code,
|
||||||
nfc_secret_trezor=ctx.nfc_secret,
|
nfc_secret_trezor=ctx.nfc_secret,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from ubinascii import hexlify
|
from ubinascii import hexlify
|
||||||
|
|
||||||
from trezor import loop, protobuf
|
from trezor import protobuf
|
||||||
from trezor.crypto import random
|
from trezor.crypto import random
|
||||||
from trezor.crypto.hashlib import sha256
|
from trezor.crypto.hashlib import sha256
|
||||||
from trezor.enums import ThpMessageType, ThpPairingMethod
|
from trezor.enums import ThpMessageType, ThpPairingMethod
|
||||||
@ -27,7 +27,7 @@ from trezor.messages import (
|
|||||||
)
|
)
|
||||||
from trezor.wire import message_handler
|
from trezor.wire import message_handler
|
||||||
from trezor.wire.context import UnexpectedMessageException
|
from trezor.wire.context import UnexpectedMessageException
|
||||||
from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage
|
from trezor.wire.errors import SilentError, UnexpectedMessage
|
||||||
from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods
|
from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods
|
||||||
from trezor.wire.thp.pairing_context import PairingContext
|
from trezor.wire.thp.pairing_context import PairingContext
|
||||||
|
|
||||||
@ -37,7 +37,7 @@ if __debug__:
|
|||||||
from trezor import log
|
from trezor import log
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any, Callable, Concatenate, Container, ParamSpec, Tuple
|
from typing import Any, Callable, Concatenate, ParamSpec, Tuple
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
|
FuncWithContext = Callable[Concatenate[PairingContext, P], Any]
|
||||||
@ -186,22 +186,6 @@ async def _prepare_pairing(ctx: PairingContext) -> None:
|
|||||||
raise Exception() # TODO unknown pairing method
|
raise Exception() # TODO unknown pairing method
|
||||||
|
|
||||||
|
|
||||||
async def show_display_data(
|
|
||||||
ctx: PairingContext, expected_types: Container[int] = ()
|
|
||||||
) -> type[protobuf.MessageType]:
|
|
||||||
from trezorui_api import CANCELLED
|
|
||||||
|
|
||||||
read_task = ctx.read(expected_types)
|
|
||||||
cancel_task = ctx.display_data.get_display_layout()
|
|
||||||
race = loop.race(read_task, cancel_task.get_result())
|
|
||||||
result: type[protobuf.MessageType] = await race
|
|
||||||
|
|
||||||
if result is CANCELLED:
|
|
||||||
raise ActionCancelled
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@check_state_and_log(ChannelState.TP1)
|
@check_state_and_log(ChannelState.TP1)
|
||||||
async def _handle_code_entry_is_selected(ctx: PairingContext) -> None:
|
async def _handle_code_entry_is_selected(ctx: PairingContext) -> None:
|
||||||
if ctx.code_entry_secret is None:
|
if ctx.code_entry_secret is None:
|
||||||
@ -231,16 +215,12 @@ async def _handle_code_entry_is_selected_first_time(ctx: PairingContext) -> None
|
|||||||
sha_ctx.update(ctx.code_entry_secret)
|
sha_ctx.update(ctx.code_entry_secret)
|
||||||
sha_ctx.update(challenge_message.challenge)
|
sha_ctx.update(challenge_message.challenge)
|
||||||
code_code_entry_hash = sha_ctx.digest()
|
code_code_entry_hash = sha_ctx.digest()
|
||||||
ctx.display_data.code_code_entry = (
|
ctx.code_code_entry = int.from_bytes(code_code_entry_hash, "big") % 1000000
|
||||||
int.from_bytes(code_code_entry_hash, "big") % 1000000
|
|
||||||
)
|
|
||||||
ctx.cpace = Cpace(
|
ctx.cpace = Cpace(
|
||||||
ctx.channel_ctx.get_handshake_hash(),
|
ctx.channel_ctx.get_handshake_hash(),
|
||||||
)
|
)
|
||||||
assert ctx.display_data.code_code_entry is not None
|
assert ctx.code_code_entry is not None
|
||||||
ctx.cpace.generate_keys_and_secret(
|
ctx.cpace.generate_keys_and_secret(ctx.code_code_entry.to_bytes(6, "big"))
|
||||||
ctx.display_data.code_code_entry.to_bytes(6, "big")
|
|
||||||
)
|
|
||||||
await ctx.write_force(
|
await ctx.write_force(
|
||||||
ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key)
|
ThpCodeEntryCpaceTrezor(cpace_trezor_public_key=ctx.cpace.trezor_public_key)
|
||||||
)
|
)
|
||||||
@ -260,7 +240,7 @@ async def _handle_qr_code_is_selected(ctx: PairingContext) -> None:
|
|||||||
sha_ctx.update(ctx.channel_ctx.get_handshake_hash())
|
sha_ctx.update(ctx.channel_ctx.get_handshake_hash())
|
||||||
sha_ctx.update(ctx.qr_code_secret)
|
sha_ctx.update(ctx.qr_code_secret)
|
||||||
|
|
||||||
ctx.display_data.code_qr_code = sha_ctx.digest()[:16]
|
ctx.code_qr_code = sha_ctx.digest()[:16]
|
||||||
await ctx.write_force(ThpPairingPreparationsFinished())
|
await ctx.write_force(ThpPairingPreparationsFinished())
|
||||||
|
|
||||||
|
|
||||||
@ -317,9 +297,9 @@ async def _handle_qr_code_tag(
|
|||||||
) -> protobuf.MessageType:
|
) -> protobuf.MessageType:
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert isinstance(message, ThpQrCodeTag)
|
assert isinstance(message, ThpQrCodeTag)
|
||||||
assert ctx.display_data.code_qr_code is not None
|
assert ctx.code_qr_code is not None
|
||||||
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
|
sha_ctx = sha256(ctx.channel_ctx.get_handshake_hash())
|
||||||
sha_ctx.update(ctx.display_data.code_qr_code)
|
sha_ctx.update(ctx.code_qr_code)
|
||||||
expected_tag = sha_ctx.digest()
|
expected_tag = sha_ctx.digest()
|
||||||
if expected_tag != message.tag:
|
if expected_tag != message.tag:
|
||||||
print(
|
print(
|
||||||
@ -331,7 +311,7 @@ async def _handle_qr_code_tag(
|
|||||||
) # TODO remove after testing
|
) # TODO remove after testing
|
||||||
print(
|
print(
|
||||||
"expected code qr code:",
|
"expected code qr code:",
|
||||||
hexlify(ctx.display_data.code_qr_code).decode(),
|
hexlify(ctx.code_qr_code).decode(),
|
||||||
) # TODO remove after testing
|
) # TODO remove after testing
|
||||||
print(
|
print(
|
||||||
"expected secret:", hexlify(ctx.qr_code_secret or b"").decode()
|
"expected secret:", hexlify(ctx.qr_code_secret or b"").decode()
|
||||||
|
@ -13,7 +13,6 @@ from trezor.wire.thp import ChannelState, get_enabled_pairing_methods
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Awaitable, Container
|
from typing import Awaitable, Container
|
||||||
|
|
||||||
from trezor import ui
|
|
||||||
from trezor.enums import ThpPairingMethod
|
from trezor.enums import ThpPairingMethod
|
||||||
from trezorui_api import UiResult
|
from trezorui_api import UiResult
|
||||||
|
|
||||||
@ -26,54 +25,6 @@ if __debug__:
|
|||||||
from trezor import log
|
from trezor import log
|
||||||
|
|
||||||
|
|
||||||
class PairingDisplayData:
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.code_code_entry: int | None = None
|
|
||||||
self.code_qr_code: bytes | None = None
|
|
||||||
self.code_nfc: bytes | None = None
|
|
||||||
|
|
||||||
def get_display_layout(self) -> ui.Layout:
|
|
||||||
from trezor import ui
|
|
||||||
|
|
||||||
# TODO have different layouts when there is only QR code or only Code Entry
|
|
||||||
qr_str = ""
|
|
||||||
code_str = ""
|
|
||||||
if self.code_qr_code is not None:
|
|
||||||
qr_str = self.get_code_qr_code_str()
|
|
||||||
if self.code_code_entry is not None:
|
|
||||||
code_str = self.get_code_code_entry_str()
|
|
||||||
|
|
||||||
return ui.Layout(
|
|
||||||
trezorui_api.show_address_details( # noqa
|
|
||||||
qr_title="Scan QR code to pair",
|
|
||||||
address=qr_str,
|
|
||||||
case_sensitive=True,
|
|
||||||
details_title="",
|
|
||||||
account="Code to rewrite:\n" + code_str,
|
|
||||||
path="",
|
|
||||||
xpubs=[],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_code_code_entry_str(self) -> str:
|
|
||||||
if self.code_code_entry is not None:
|
|
||||||
code_str = f"{self.code_code_entry:06}"
|
|
||||||
if __debug__:
|
|
||||||
log.debug(__name__, "code_code_entry: %s", code_str)
|
|
||||||
|
|
||||||
return code_str[:3] + " " + code_str[3:]
|
|
||||||
raise Exception("Code entry string is not available")
|
|
||||||
|
|
||||||
def get_code_qr_code_str(self) -> str:
|
|
||||||
if self.code_qr_code is not None:
|
|
||||||
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
|
|
||||||
if __debug__:
|
|
||||||
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
|
|
||||||
return code_str
|
|
||||||
raise Exception("QR code string is not available")
|
|
||||||
|
|
||||||
|
|
||||||
class PairingContext(Context):
|
class PairingContext(Context):
|
||||||
|
|
||||||
def __init__(self, channel_ctx: Channel) -> None:
|
def __init__(self, channel_ctx: Channel) -> None:
|
||||||
@ -86,11 +37,13 @@ class PairingContext(Context):
|
|||||||
|
|
||||||
self.selected_method: ThpPairingMethod
|
self.selected_method: ThpPairingMethod
|
||||||
|
|
||||||
|
self.code_code_entry: int | None = None
|
||||||
|
self.code_qr_code: bytes | None = None
|
||||||
|
self.code_nfc: bytes | None = None
|
||||||
# The 2 following attributes are important for NFC pairing
|
# The 2 following attributes are important for NFC pairing
|
||||||
self.nfc_secret_host: bytes | None = None
|
self.nfc_secret_host: bytes | None = None
|
||||||
self.handshake_hash_host: bytes | None = None
|
self.handshake_hash_host: bytes | None = None
|
||||||
|
|
||||||
self.display_data: PairingDisplayData = PairingDisplayData()
|
|
||||||
self.cpace: Cpace
|
self.cpace: Cpace
|
||||||
self.host_name: str | None
|
self.host_name: str | None
|
||||||
|
|
||||||
@ -117,25 +70,19 @@ class PairingContext(Context):
|
|||||||
try:
|
try:
|
||||||
next_message = await handle_message(self, message)
|
next_message = await handle_message(self, message)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
# Log and ignore. The session handler can only exit explicitly in the
|
# Log and ignore. The context handler can only exit explicitly in the
|
||||||
# following finally block.
|
# following finally block.
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.exception(__name__, exc)
|
log.exception(__name__, exc)
|
||||||
finally:
|
finally:
|
||||||
# Unload modules imported by the workflow. Should not raise.
|
|
||||||
# This is not done for the debug session because the snapshot taken
|
|
||||||
# in a debug session would clear modules which are in use by the
|
|
||||||
# workflow running on wire.
|
|
||||||
# TODO utils.unimport_end(modules)
|
|
||||||
|
|
||||||
if next_message is None:
|
if next_message is None:
|
||||||
|
|
||||||
# Shut down the loop if there is no next message waiting.
|
# Shut down the loop if there is no next message waiting.
|
||||||
return # pylint: disable=lost-exception
|
return # pylint: disable=lost-exception
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
# Log and try again. The session handler can only exit explicitly via
|
# Log and try again. The context handler can only exit explicitly via
|
||||||
# loop.clear() above. # TODO not updated comments
|
# finally block above
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.exception(__name__, exc)
|
log.exception(__name__, exc)
|
||||||
|
|
||||||
@ -236,45 +183,75 @@ class PairingContext(Context):
|
|||||||
self, selected_method: ThpPairingMethod | None = None
|
self, selected_method: ThpPairingMethod | None = None
|
||||||
) -> UiResult:
|
) -> UiResult:
|
||||||
from trezor.enums import ThpPairingMethod
|
from trezor.enums import ThpPairingMethod
|
||||||
from trezor.ui.layouts.common import interact
|
|
||||||
|
|
||||||
if selected_method is None:
|
if selected_method is None:
|
||||||
selected_method = self.selected_method
|
selected_method = self.selected_method
|
||||||
if selected_method is ThpPairingMethod.CodeEntry:
|
if selected_method is ThpPairingMethod.CodeEntry:
|
||||||
result = await interact(
|
return await self._show_code_entry_screen()
|
||||||
trezorui_api.show_simple(
|
|
||||||
title="Copy the following",
|
|
||||||
text=self.display_data.get_code_code_entry_str(),
|
|
||||||
),
|
|
||||||
br_name="pairing_code_entry",
|
|
||||||
br_code=ButtonRequestType.Other,
|
|
||||||
)
|
|
||||||
elif selected_method is ThpPairingMethod.QrCode:
|
|
||||||
result = await interact(
|
|
||||||
trezorui_api.show_address_details( # noqa
|
|
||||||
qr_title="Scan QR code to pair",
|
|
||||||
address=self.display_data.get_code_qr_code_str(),
|
|
||||||
case_sensitive=True,
|
|
||||||
details_title="",
|
|
||||||
account="",
|
|
||||||
path="",
|
|
||||||
xpubs=[],
|
|
||||||
),
|
|
||||||
br_name="pairing_qr_code",
|
|
||||||
br_code=ButtonRequestType.Other,
|
|
||||||
)
|
|
||||||
elif selected_method is ThpPairingMethod.NFC:
|
elif selected_method is ThpPairingMethod.NFC:
|
||||||
result = await interact(
|
return await self._show_nfc_screen()
|
||||||
trezorui_api.show_simple(
|
elif selected_method is ThpPairingMethod.QrCode:
|
||||||
title="NFC Pairing",
|
return await self._show_qr_code_screen()
|
||||||
text="Move your device close to Trezor",
|
|
||||||
),
|
|
||||||
br_name="pairing_nfc",
|
|
||||||
br_code=ButtonRequestType.Other,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise Exception("Unknown pairing method")
|
raise Exception("Unknown pairing method")
|
||||||
return result
|
|
||||||
|
async def _show_code_entry_screen(self) -> UiResult:
|
||||||
|
from trezor.ui.layouts.common import interact
|
||||||
|
|
||||||
|
return await interact(
|
||||||
|
trezorui_api.show_simple(
|
||||||
|
title="Copy the following",
|
||||||
|
text=self._get_code_code_entry_str(),
|
||||||
|
),
|
||||||
|
br_name="pairing_code_entry",
|
||||||
|
br_code=ButtonRequestType.Other,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _show_nfc_screen(self) -> UiResult:
|
||||||
|
from trezor.ui.layouts.common import interact
|
||||||
|
|
||||||
|
return await interact(
|
||||||
|
trezorui_api.show_simple(
|
||||||
|
title="NFC Pairing",
|
||||||
|
text="Move your device close to Trezor",
|
||||||
|
),
|
||||||
|
br_name="pairing_nfc",
|
||||||
|
br_code=ButtonRequestType.Other,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _show_qr_code_screen(self) -> UiResult:
|
||||||
|
from trezor.ui.layouts.common import interact
|
||||||
|
|
||||||
|
return await interact(
|
||||||
|
trezorui_api.show_address_details( # noqa
|
||||||
|
qr_title="Scan QR code to pair",
|
||||||
|
address=self._get_code_qr_code_str(),
|
||||||
|
case_sensitive=True,
|
||||||
|
details_title="",
|
||||||
|
account="",
|
||||||
|
path="",
|
||||||
|
xpubs=[],
|
||||||
|
),
|
||||||
|
br_name="pairing_qr_code",
|
||||||
|
br_code=ButtonRequestType.Other,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_code_code_entry_str(self) -> str:
|
||||||
|
if self.code_code_entry is not None:
|
||||||
|
code_str = f"{self.code_code_entry:06}"
|
||||||
|
if __debug__:
|
||||||
|
log.debug(__name__, "code_code_entry: %s", code_str)
|
||||||
|
|
||||||
|
return code_str[:3] + " " + code_str[3:]
|
||||||
|
raise Exception("Code entry string is not available")
|
||||||
|
|
||||||
|
def _get_code_qr_code_str(self) -> str:
|
||||||
|
if self.code_qr_code is not None:
|
||||||
|
code_str = (hexlify(self.code_qr_code)).decode("utf-8")
|
||||||
|
if __debug__:
|
||||||
|
log.debug(__name__, "code_qr_code_hexlified: %s", code_str)
|
||||||
|
return code_str
|
||||||
|
raise Exception("QR code string is not available")
|
||||||
|
|
||||||
|
|
||||||
async def handle_message(
|
async def handle_message(
|
||||||
|
Loading…
Reference in New Issue
Block a user