diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 9357bd100e..23f8f38165 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -208,6 +208,7 @@ if __debug__: msg: DebugLinkDecision, ) -> DebugLinkState | None: from trezor import ui, workflow + log.debug(__name__, "decision 1") workflow.idle_timer.touch() diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 9289282e48..e7a7b0597f 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -51,6 +51,7 @@ LOG = logging.getLogger(__name__) class TrezorClient: button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None _management_session: Session | None = None @@ -135,8 +136,8 @@ class TrezorClient: """ Note: this function potentially modifies the input session. """ - from trezorlib.transport.session import SessionV1, SessionV2 - from trezorlib.debuglink import SessionDebugWrapper + from .debuglink import SessionDebugWrapper + from .transport.session import SessionV1, SessionV2 if isinstance(session, SessionDebugWrapper): session = session._session diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 00bbdbd27e..df0587c351 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -32,10 +32,10 @@ from pathlib import Path from mnemonic import Mnemonic from . import btc, mapping, messages, models, protobuf -from .client import TrezorClient -from .exceptions import TrezorFailure +from .client import MAX_PASSPHRASE_LENGTH, PASSPHRASE_ON_DEVICE, TrezorClient +from .exceptions import Cancelled, TrezorFailure from .log import DUMP_BYTES -from .messages import DebugWaitType +from .messages import Capability, DebugWaitType from .tools import expect, parse_path from .transport.session import Session, SessionV1, SessionV2 from .transport.thp.protocol_v1 import ProtocolV1 @@ -553,7 +553,7 @@ class DebugLink: def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: # Next layout change will be caused by external event - # (e.g. device being auto-locked or as a result of device_handler.run(xxx)) + # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx)) # and not by our debug actions/decisions. # Resetting the debug state so we wait for the next layout change # (and do not return the current state). @@ -896,11 +896,6 @@ class DebugUI: else: self.debuglink.press_yes() - def debug_callback_button(self, session: Session, msg: t.Any) -> t.Any: - session._write(messages.ButtonAck()) - self.button_request(msg) - return session._read() - def button_request(self, br: messages.ButtonRequest) -> None: self.debuglink.snapshot_legacy() @@ -1337,7 +1332,68 @@ class TrezorClientDebugLink(TrezorClient): @property def button_callback(self): - return self.ui.debug_callback_button + + def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + # do this raw - send ButtonAck first, notify UI later + session._write(messages.ButtonAck()) + self.ui.button_request(msg) + return session._read() + + return _callback_button + + @property + def passphrase_callback(self): + def _callback_passphrase( + session: Session, msg: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) + + def send_passphrase( + passphrase: str | None = None, on_device: bool | None = None + ) -> t.Any: + msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) + resp = session.call_raw(msg) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + # session.session_id = resp.state + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + return resp + + # short-circuit old style entry + if msg._on_device is True: + return send_passphrase(None, None) + + try: + if isinstance(session, SessionV1): + passphrase = self.ui.get_passphrase( + available_on_device=available_on_device + ) + else: + passphrase = session.passphrase + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if passphrase is PASSPHRASE_ON_DEVICE: + if not available_on_device: + session.call_raw(messages.Cancel()) + raise RuntimeError("Device is not capable of entering passphrase") + else: + return send_passphrase(on_device=True) + + # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + passphrase = Mnemonic.normalize_string(passphrase) + if len(passphrase) > MAX_PASSPHRASE_LENGTH: + session.call_raw(messages.Cancel()) + raise ValueError("Passphrase too long") + + return send_passphrase(passphrase, on_device=False) + + return _callback_passphrase def ensure_open(self) -> None: """Only open session if there isn't already an open one.""" diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index a265e8a041..aff585d86c 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -3,11 +3,6 @@ from __future__ import annotations import logging import typing as t -from mnemonic import Mnemonic -from ..client import MAX_PASSPHRASE_LENGTH, PASSPHRASE_ON_DEVICE - -from ..messages import Capability - from .. import exceptions, messages, models from .thp.protocol_v1 import ProtocolV1 from .thp.protocol_v2 import ProtocolV2 @@ -48,6 +43,7 @@ class Session: elif isinstance(resp, messages.PassphraseRequest): if self.passphrase_callback is None: raise Exception # TODO + print(self.passphrase_callback) resp = self.passphrase_callback(self, resp) elif isinstance(resp, messages.ButtonRequest): if self.button_callback is None: @@ -95,6 +91,7 @@ class Session: class SessionV1(Session): derive_cardano: bool = False + @classmethod def new( cls, client: TrezorClient, passphrase: str = "", derive_cardano: bool = False @@ -108,7 +105,7 @@ class SessionV1(Session): session = SessionV1(client, session_id) session.button_callback = client.button_callback session.pin_callback = client.pin_callback - session.passphrase_callback = _callback_passphrase + session.passphrase_callback = client.passphrase_callback session.passphrase = passphrase session.derive_cardano = derive_cardano session.init_session() @@ -142,46 +139,6 @@ def _callback_button(session: Session, msg: t.Any) -> t.Any: return session.call(messages.ButtonAck()) -def _callback_passphrase(session: Session, msg: messages.PassphraseRequest) -> t.Any: - available_on_device = Capability.PassphraseEntry in session.features.capabilities - def send_passphrase( - passphrase: str | None = None, on_device: bool | None = None - ) -> t.Any: - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = session.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - session.session_id = resp.state - resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - passphrase = session.passphrase - except exceptions.Cancelled: - session.call_raw(messages.Cancel()) - raise - - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - session.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) - - # else process host-entered passphrase - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - session.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - class SessionV2(Session): @classmethod diff --git a/tests/click_tests/test_autolock.py b/tests/click_tests/test_autolock.py index 12669bb860..25926df1c0 100644 --- a/tests/click_tests/test_autolock.py +++ b/tests/click_tests/test_autolock.py @@ -22,6 +22,7 @@ import pytest from trezorlib import btc, device, exceptions, messages from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.protobuf import MessageType from trezorlib.tools import parse_path @@ -58,7 +59,7 @@ CENTER_BUTTON = buttons.grid35(1, 2) def set_autolock_delay(device_handler: "BackgroundDeviceHandler", delay_ms: int): debug = device_handler.debuglink() - + Session(device_handler.client.get_management_session()).lock() device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore assert "PinKeyboard" in debug.read_layout().all_components() @@ -97,7 +98,7 @@ def test_autolock_interrupts_signing(device_handler: "BackgroundDeviceHandler"): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - device_handler.run(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore + device_handler.run_with_session(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore assert ( "1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1" @@ -132,6 +133,10 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() + + # Prepare session to use later + session = Session(device_handler.client.get_session()) + # try to sign a transaction inp1 = messages.TxInputType( address_n=parse_path("86h/0h/0h/0/0"), @@ -147,8 +152,8 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa script_type=messages.OutputScriptType.PAYTOADDRESS, ) - device_handler.run( - btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET + device_handler.run_with_provided_session( + session, btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET ) assert ( @@ -175,11 +180,11 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - device_handler.client.set_filter(messages.TxAck, None) + session.set_filter(messages.TxAck, None) return msg - with device_handler.client: - device_handler.client.set_filter(messages.TxAck, sleepy_filter) + with session, device_handler.client: + session.set_filter(messages.TxAck, sleepy_filter) # confirm transaction # In all cases we set wait=False to avoid waiting for the screen and triggering # the layout deadlock detection. In reality there is no deadlock but the @@ -187,7 +192,7 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa # timeout is 3. In this test we don't need the result of the input event so # waiting for it is not necessary. if debug.layout_type is LayoutType.TT: - debug.click(buttons.OK, wait=False) + debug.click(buttons.OK, hold_ms=1000, wait=False) elif debug.layout_type is LayoutType.Mercury: debug.click(buttons.TAP_TO_CONFIRM, wait=False) elif debug.layout_type is LayoutType.TR: @@ -196,7 +201,6 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa signatures, tx = device_handler.result() assert len(signatures) == 1 assert tx - assert device_handler.features().unlocked is False @@ -206,7 +210,7 @@ def test_autolock_passphrase_keyboard(device_handler: "BackgroundDeviceHandler") debug = device_handler.debuglink() # get address - device_handler.run(common.get_test_address) # type: ignore + device_handler.run_with_session(common.get_test_address) # type: ignore assert "PassphraseKeyboard" in debug.read_layout().all_components() @@ -248,7 +252,7 @@ def test_autolock_interrupts_passphrase(device_handler: "BackgroundDeviceHandler debug = device_handler.debuglink() # get address - device_handler.run(common.get_test_address) # type: ignore + device_handler.run_with_session(common.get_test_address) # type: ignore assert "PassphraseKeyboard" in debug.read_layout().all_components() @@ -287,7 +291,7 @@ def test_dryrun_locks_at_number_of_words(device_handler: "BackgroundDeviceHandle set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) layout = unlock_dry_run(debug) assert TR.recovery__num_of_words in debug.read_layout().text_content() @@ -319,7 +323,7 @@ def test_dryrun_locks_at_word_entry(device_handler: "BackgroundDeviceHandler"): set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) unlock_dry_run(debug) @@ -345,7 +349,7 @@ def test_dryrun_enter_word_slowly(device_handler: "BackgroundDeviceHandler"): set_autolock_delay(device_handler, 10_000) debug = device_handler.debuglink() - device_handler.run(device.recover, type=messages.RecoveryType.DryRun) + device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun) unlock_dry_run(debug) @@ -405,7 +409,11 @@ def test_autolock_does_not_interrupt_preauthorized( debug = device_handler.debuglink() - device_handler.run( + # Prepare session to use later + session = Session(device_handler.client.get_session()) + + device_handler.run_with_provided_session( + session, btc.authorize_coinjoin, coordinator="www.example.com", max_rounds=2, @@ -519,14 +527,15 @@ def test_autolock_does_not_interrupt_preauthorized( def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - device_handler.client.set_filter(messages.SignTx, None) + session.set_filter(messages.SignTx, None) return msg - with device_handler.client: + with session: # Start DoPreauthorized flow when device is unlocked. Wait 10s before # delivering SignTx, by that time autolock timer should have fired. - device_handler.client.set_filter(messages.SignTx, sleepy_filter) - device_handler.run( + session.set_filter(messages.SignTx, sleepy_filter) + device_handler.run_with_provided_session( + session, btc.sign_tx, "Testnet", inputs, diff --git a/tests/click_tests/test_backup_slip39_custom.py b/tests/click_tests/test_backup_slip39_custom.py index be01683d07..0976a08ad3 100644 --- a/tests/click_tests/test_backup_slip39_custom.py +++ b/tests/click_tests/test_backup_slip39_custom.py @@ -53,7 +53,9 @@ def test_backup_slip39_custom( assert features.initialized is False - device_handler.run( + session = device_handler.client.get_management_session() + device_handler.run_with_provided_session( + session, device.reset, strength=128, backup_type=messages.BackupType.Slip39_Basic, @@ -68,7 +70,7 @@ def test_backup_slip39_custom( assert device_handler.result() == "Initialized" - device_handler.run( + device_handler.run_with_session( device.backup, group_threshold=group_threshold, groups=[(share_threshold, share_count)], diff --git a/tests/click_tests/test_lock.py b/tests/click_tests/test_lock.py index b3656dfd29..afaacb078c 100644 --- a/tests/click_tests/test_lock.py +++ b/tests/click_tests/test_lock.py @@ -65,20 +65,22 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"): device_handler.run_with_session(common.get_test_address) assert "PinKeyboard" in debug.read_layout().all_components() - time.sleep(10) debug.input("1234") assert device_handler.result() + session.refresh_features() assert device_handler.features().unlocked is True # short touch hold(short_duration) time.sleep(0.5) # so that the homescreen appears again (hacky) + session.refresh_features() assert device_handler.features().unlocked is True # lock hold(lock_duration) + session.refresh_features() assert device_handler.features().unlocked is False # unlock by touching @@ -89,8 +91,10 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"): assert "PinKeyboard" in layout.all_components() debug.input("1234") + session.refresh_features() assert device_handler.features().unlocked is True # lock hold(lock_duration) + session.refresh_features() assert device_handler.features().unlocked is False diff --git a/tests/click_tests/test_passphrase_mercury.py b/tests/click_tests/test_passphrase_mercury.py index 9bed04da84..d0783e0dcd 100644 --- a/tests/click_tests/test_passphrase_mercury.py +++ b/tests/click_tests/test_passphrase_mercury.py @@ -97,7 +97,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore # TODO assert debug.read_layout().main_component() == "PassphraseKeyboard" diff --git a/tests/click_tests/test_passphrase_tr.py b/tests/click_tests/test_passphrase_tr.py index 57685451ba..0affa4fbb6 100644 --- a/tests/click_tests/test_passphrase_tr.py +++ b/tests/click_tests/test_passphrase_tr.py @@ -91,7 +91,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore layout = debug.read_layout() assert "PassphraseKeyboard" in layout.all_components() assert layout.passphrase() == "" diff --git a/tests/click_tests/test_passphrase_tt.py b/tests/click_tests/test_passphrase_tt.py index 8f490c0309..79993b954f 100644 --- a/tests/click_tests/test_passphrase_tt.py +++ b/tests/click_tests/test_passphrase_tt.py @@ -69,7 +69,7 @@ def prepare_passphrase_dialogue( device_handler: "BackgroundDeviceHandler", address: Optional[str] = None ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(get_test_address) # type: ignore + device_handler.run_with_session(get_test_address) # type: ignore assert debug.read_layout().main_component() == "PassphraseKeyboard" # Resetting the category as it could have been changed by previous tests diff --git a/tests/click_tests/test_pin.py b/tests/click_tests/test_pin.py index 48f54c5573..4d8afedb31 100644 --- a/tests/click_tests/test_pin.py +++ b/tests/click_tests/test_pin.py @@ -23,6 +23,7 @@ import pytest from trezorlib import device, exceptions from trezorlib.debuglink import LayoutType +from trezorlib.debuglink import SessionDebugWrapper as Session from .. import buttons from .. import translations as TR @@ -91,17 +92,19 @@ def prepare( tap = False + Session(device_handler.client.get_management_session()).lock() + # Setup according to the wanted situation if situation == Situation.PIN_INPUT: # Any action triggering the PIN dialogue - device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore tap = True if situation == Situation.PIN_INPUT_CANCEL: # Any action triggering the PIN dialogue - device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore + device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore elif situation == Situation.PIN_SETUP: # Set new PIN - device_handler.run(device.change_pin) # type: ignore + device_handler.run_with_session(device.change_pin) # type: ignore assert ( TR.pin__turn_on in debug.read_layout().text_content() or TR.pin__info in debug.read_layout().text_content() @@ -115,14 +118,14 @@ def prepare( go_next(debug) elif situation == Situation.PIN_CHANGE: # Change PIN - device_handler.run(device.change_pin) # type: ignore + device_handler.run_with_session(device.change_pin) # type: ignore _input_see_confirm(debug, old_pin) assert TR.pin__change in debug.read_layout().text_content() go_next(debug) _input_see_confirm(debug, old_pin) elif situation == Situation.WIPE_CODE_SETUP: # Set wipe code - device_handler.run(device.change_wipe_code) # type: ignore + device_handler.run_with_session(device.change_wipe_code) # type: ignore if old_pin: _input_see_confirm(debug, old_pin) assert TR.wipe_code__turn_on in debug.read_layout().text_content() diff --git a/tests/click_tests/test_recovery.py b/tests/click_tests/test_recovery.py index 8770649296..769b2b507c 100644 --- a/tests/click_tests/test_recovery.py +++ b/tests/click_tests/test_recovery.py @@ -40,7 +40,7 @@ def prepare_recovery_and_evaluate( features = device_handler.features() debug = device_handler.debuglink() assert features.initialized is False - device_handler.run(device.recover, pin_protection=False) # type: ignore + device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore yield debug diff --git a/tests/click_tests/test_repeated_backup.py b/tests/click_tests/test_repeated_backup.py index d61d97962d..55fd4157ef 100644 --- a/tests/click_tests/test_repeated_backup.py +++ b/tests/click_tests/test_repeated_backup.py @@ -42,7 +42,7 @@ def test_repeated_backup( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.reset, strength=128, backup_type=messages.BackupType.Slip39_Basic, @@ -94,7 +94,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # run recovery to unlock backup - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) @@ -161,7 +161,7 @@ def test_repeated_backup( assert features.recovery_status == messages.RecoveryStatus.Nothing # try to unlock backup again... - device_handler.run( + device_handler.run_with_session( device.recover, type=messages.RecoveryType.UnlockRepeatedBackup, ) diff --git a/tests/click_tests/test_reset_bip39.py b/tests/click_tests/test_reset_bip39.py index 907246fb51..18692b1279 100644 --- a/tests/click_tests/test_reset_bip39.py +++ b/tests/click_tests/test_reset_bip39.py @@ -40,7 +40,7 @@ def test_reset_bip39(device_handler: "BackgroundDeviceHandler"): assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.reset, strength=128, backup_type=messages.BackupType.Bip39, diff --git a/tests/click_tests/test_reset_slip39_advanced.py b/tests/click_tests/test_reset_slip39_advanced.py index 874ad7a621..d26a55fb00 100644 --- a/tests/click_tests/test_reset_slip39_advanced.py +++ b/tests/click_tests/test_reset_slip39_advanced.py @@ -52,7 +52,7 @@ def test_reset_slip39_advanced( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.reset, backup_type=messages.BackupType.Slip39_Advanced, pin_protection=False, diff --git a/tests/click_tests/test_reset_slip39_basic.py b/tests/click_tests/test_reset_slip39_basic.py index f8c6592f6d..fbdd8f63f7 100644 --- a/tests/click_tests/test_reset_slip39_basic.py +++ b/tests/click_tests/test_reset_slip39_basic.py @@ -48,7 +48,7 @@ def test_reset_slip39_basic( assert features.initialized is False - device_handler.run( + device_handler.run_with_session( device.reset, strength=128, backup_type=messages.BackupType.Slip39_Basic, diff --git a/tests/click_tests/test_tutorial_mercury.py b/tests/click_tests/test_tutorial_mercury.py index 987b32b48c..7129cf9131 100644 --- a/tests/click_tests/test_tutorial_mercury.py +++ b/tests/click_tests/test_tutorial_mercury.py @@ -36,7 +36,7 @@ pytestmark = [ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 @@ -57,7 +57,7 @@ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 @@ -84,7 +84,7 @@ def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 @@ -108,7 +108,7 @@ def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 @@ -139,7 +139,7 @@ def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_funfact(device_handler: "BackgroundDeviceHandler"): debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) layout = debug.read_layout() assert layout.title() == TR.tutorial__welcome_safe5 diff --git a/tests/click_tests/test_tutorial_tr.py b/tests/click_tests/test_tutorial_tr.py index 81d2645ace..88dc895a64 100644 --- a/tests/click_tests/test_tutorial_tr.py +++ b/tests/click_tests/test_tutorial_tr.py @@ -39,7 +39,7 @@ def prepare_tutorial_and_cancel_after_it( device_handler: "BackgroundDeviceHandler", cancelled: bool = False ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() - device_handler.run(device.show_device_tutorial) + device_handler.run_with_session(device.show_device_tutorial) yield debug diff --git a/tests/device_handler.py b/tests/device_handler.py index a380ea80a3..f78686922c 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -33,9 +33,6 @@ class NullUI: else: raise NotImplementedError("NullUI should not be used with T1") - def debug_callback_button(self, session: Any, msg: Any) -> Any: - raise RuntimeError("unexpected call to a fake debuglink") - class BackgroundDeviceHandler: _pool = ThreadPoolExecutor() @@ -78,6 +75,20 @@ class BackgroundDeviceHandler: session = self.client.get_session() self.task = self._pool.submit(function, session, *args, **kwargs) + def run_with_provided_session( + self, session, function: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: + """Runs some function that interacts with a device. + + Makes sure the UI is updated before returning. + """ + if self.task is not None: + raise RuntimeError("Wait for previous task first") + + # wait for the first UI change triggered by the task running in the background + with self.debuglink().wait_for_layout_change(): + self.task = self._pool.submit(function, session, *args, **kwargs) + def kill_task(self) -> None: if self.task is not None: # Force close the client, which should raise an exception in a client @@ -108,6 +119,7 @@ class BackgroundDeviceHandler: def features(self) -> "Features": if self.task is not None: raise RuntimeError("Cannot query features while task is running") + self.client.refresh_features() return self.client.features def debuglink(self) -> "DebugLink": diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 41eeaf770d..ff70180eeb 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -794,7 +794,6 @@ def test_get_address(session: Session): def test_multisession_authorization(client: Client): # Authorize CoinJoin with www.example1.com in session 1. session1 = client.get_session() - session_id_1 = session1.id btc.authorize_coinjoin( session1, @@ -807,7 +806,6 @@ def test_multisession_authorization(client: Client): script_type=messages.InputScriptType.SPENDTAPROOT, ) session2 = client.get_session() - session_id_2 = session2.id # Open a second session. # session_id1 = session.session_id # TODO client.init_device(new_session=True) diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index c88f85167c..ef801d4b41 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -18,8 +18,8 @@ import pytest from trezorlib import cardano, messages, models from trezorlib.btc import get_public_node -from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import SessionDebugWrapper as Session +from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.tools import parse_path @@ -159,7 +159,7 @@ def test_session_recycling(client: Client): ] ) client.use_passphrase("TREZOR") - address = get_test_address(session) + # address = get_test_address(session) # create and close 100 sessions - more than the session limit for _ in range(100): diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 93beed180a..f8949255f4 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -20,8 +20,8 @@ import pytest from trezorlib import device, exceptions, messages from trezorlib.debuglink import LayoutType -from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import SessionDebugWrapper as Session +from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import FailureType, SafetyCheckLevel from trezorlib.tools import parse_path diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index 9a778f6055..f70cde9673 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -20,8 +20,8 @@ import pytest from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib._internal.emulator import Emulator -from trezorlib.tools import parse_path from trezorlib.debuglink import SessionDebugWrapper as Session +from trezorlib.tools import parse_path from ..emulators import EmulatorWrapper from . import for_all