1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-01 11:58:28 +00:00

test(python): fix click_tests

[no changelog]
This commit is contained in:
M1nd3r 2024-11-25 19:02:26 +01:00
parent aaa9dfbb30
commit bb89708a94
23 changed files with 153 additions and 110 deletions

View File

@ -208,6 +208,7 @@ if __debug__:
msg: DebugLinkDecision, msg: DebugLinkDecision,
) -> DebugLinkState | None: ) -> DebugLinkState | None:
from trezor import ui, workflow from trezor import ui, workflow
log.debug(__name__, "decision 1") log.debug(__name__, "decision 1")
workflow.idle_timer.touch() workflow.idle_timer.touch()

View File

@ -51,6 +51,7 @@ LOG = logging.getLogger(__name__)
class TrezorClient: class TrezorClient:
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
_management_session: Session | None = None _management_session: Session | None = None
@ -135,8 +136,8 @@ class TrezorClient:
""" """
Note: this function potentially modifies the input session. Note: this function potentially modifies the input session.
""" """
from trezorlib.transport.session import SessionV1, SessionV2 from .debuglink import SessionDebugWrapper
from trezorlib.debuglink import SessionDebugWrapper from .transport.session import SessionV1, SessionV2
if isinstance(session, SessionDebugWrapper): if isinstance(session, SessionDebugWrapper):
session = session._session session = session._session

View File

@ -32,10 +32,10 @@ from pathlib import Path
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import btc, mapping, messages, models, protobuf from . import btc, mapping, messages, models, protobuf
from .client import TrezorClient from .client import MAX_PASSPHRASE_LENGTH, PASSPHRASE_ON_DEVICE, TrezorClient
from .exceptions import TrezorFailure from .exceptions import Cancelled, TrezorFailure
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import DebugWaitType from .messages import Capability, DebugWaitType
from .tools import expect, parse_path from .tools import expect, parse_path
from .transport.session import Session, SessionV1, SessionV2 from .transport.session import Session, SessionV1, SessionV2
from .transport.thp.protocol_v1 import ProtocolV1 from .transport.thp.protocol_v1 import ProtocolV1
@ -553,7 +553,7 @@ class DebugLink:
def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent:
# Next layout change will be caused by external event # Next layout change will be caused by external event
# (e.g. device being auto-locked or as a result of device_handler.run(xxx)) # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx))
# and not by our debug actions/decisions. # and not by our debug actions/decisions.
# Resetting the debug state so we wait for the next layout change # Resetting the debug state so we wait for the next layout change
# (and do not return the current state). # (and do not return the current state).
@ -896,11 +896,6 @@ class DebugUI:
else: else:
self.debuglink.press_yes() self.debuglink.press_yes()
def debug_callback_button(self, session: Session, msg: t.Any) -> t.Any:
session._write(messages.ButtonAck())
self.button_request(msg)
return session._read()
def button_request(self, br: messages.ButtonRequest) -> None: def button_request(self, br: messages.ButtonRequest) -> None:
self.debuglink.snapshot_legacy() self.debuglink.snapshot_legacy()
@ -1337,7 +1332,68 @@ class TrezorClientDebugLink(TrezorClient):
@property @property
def button_callback(self): def button_callback(self):
return self.ui.debug_callback_button
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
session._write(messages.ButtonAck())
self.ui.button_request(msg)
return session._read()
return _callback_button
@property
def passphrase_callback(self):
def _callback_passphrase(
session: Session, msg: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
def send_passphrase(
passphrase: str | None = None, on_device: bool | None = None
) -> t.Any:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
# session.session_id = resp.state
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
if isinstance(session, SessionV1):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
)
else:
passphrase = session.passphrase
except Cancelled:
session.call_raw(messages.Cancel())
raise
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
session.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
return _callback_passphrase
def ensure_open(self) -> None: def ensure_open(self) -> None:
"""Only open session if there isn't already an open one.""" """Only open session if there isn't already an open one."""

View File

@ -3,11 +3,6 @@ from __future__ import annotations
import logging import logging
import typing as t import typing as t
from mnemonic import Mnemonic
from ..client import MAX_PASSPHRASE_LENGTH, PASSPHRASE_ON_DEVICE
from ..messages import Capability
from .. import exceptions, messages, models from .. import exceptions, messages, models
from .thp.protocol_v1 import ProtocolV1 from .thp.protocol_v1 import ProtocolV1
from .thp.protocol_v2 import ProtocolV2 from .thp.protocol_v2 import ProtocolV2
@ -48,6 +43,7 @@ class Session:
elif isinstance(resp, messages.PassphraseRequest): elif isinstance(resp, messages.PassphraseRequest):
if self.passphrase_callback is None: if self.passphrase_callback is None:
raise Exception # TODO raise Exception # TODO
print(self.passphrase_callback)
resp = self.passphrase_callback(self, resp) resp = self.passphrase_callback(self, resp)
elif isinstance(resp, messages.ButtonRequest): elif isinstance(resp, messages.ButtonRequest):
if self.button_callback is None: if self.button_callback is None:
@ -95,6 +91,7 @@ class Session:
class SessionV1(Session): class SessionV1(Session):
derive_cardano: bool = False derive_cardano: bool = False
@classmethod @classmethod
def new( def new(
cls, client: TrezorClient, passphrase: str = "", derive_cardano: bool = False cls, client: TrezorClient, passphrase: str = "", derive_cardano: bool = False
@ -108,7 +105,7 @@ class SessionV1(Session):
session = SessionV1(client, session_id) session = SessionV1(client, session_id)
session.button_callback = client.button_callback session.button_callback = client.button_callback
session.pin_callback = client.pin_callback session.pin_callback = client.pin_callback
session.passphrase_callback = _callback_passphrase session.passphrase_callback = client.passphrase_callback
session.passphrase = passphrase session.passphrase = passphrase
session.derive_cardano = derive_cardano session.derive_cardano = derive_cardano
session.init_session() session.init_session()
@ -142,46 +139,6 @@ def _callback_button(session: Session, msg: t.Any) -> t.Any:
return session.call(messages.ButtonAck()) return session.call(messages.ButtonAck())
def _callback_passphrase(session: Session, msg: messages.PassphraseRequest) -> t.Any:
available_on_device = Capability.PassphraseEntry in session.features.capabilities
def send_passphrase(
passphrase: str | None = None, on_device: bool | None = None
) -> t.Any:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
session.session_id = resp.state
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
passphrase = session.passphrase
except exceptions.Cancelled:
session.call_raw(messages.Cancel())
raise
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
session.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
class SessionV2(Session): class SessionV2(Session):
@classmethod @classmethod

View File

@ -22,6 +22,7 @@ import pytest
from trezorlib import btc, device, exceptions, messages from trezorlib import btc, device, exceptions, messages
from trezorlib.debuglink import LayoutType from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.protobuf import MessageType from trezorlib.protobuf import MessageType
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
@ -58,7 +59,7 @@ CENTER_BUTTON = buttons.grid35(1, 2)
def set_autolock_delay(device_handler: "BackgroundDeviceHandler", delay_ms: int): def set_autolock_delay(device_handler: "BackgroundDeviceHandler", delay_ms: int):
debug = device_handler.debuglink() debug = device_handler.debuglink()
Session(device_handler.client.get_management_session()).lock()
device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=delay_ms) # type: ignore
assert "PinKeyboard" in debug.read_layout().all_components() assert "PinKeyboard" in debug.read_layout().all_components()
@ -97,7 +98,7 @@ def test_autolock_interrupts_signing(device_handler: "BackgroundDeviceHandler"):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
device_handler.run(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore device_handler.run_with_session(btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET) # type: ignore
assert ( assert (
"1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1" "1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1"
@ -132,6 +133,10 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
set_autolock_delay(device_handler, 10_000) set_autolock_delay(device_handler, 10_000)
debug = device_handler.debuglink() debug = device_handler.debuglink()
# Prepare session to use later
session = Session(device_handler.client.get_session())
# try to sign a transaction # try to sign a transaction
inp1 = messages.TxInputType( inp1 = messages.TxInputType(
address_n=parse_path("86h/0h/0h/0/0"), address_n=parse_path("86h/0h/0h/0/0"),
@ -147,8 +152,8 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
device_handler.run( device_handler.run_with_provided_session(
btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET session, btc.sign_tx, "Bitcoin", [inp1], [out1], prev_txes=TX_CACHE_MAINNET
) )
assert ( assert (
@ -175,11 +180,11 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
def sleepy_filter(msg: MessageType) -> MessageType: def sleepy_filter(msg: MessageType) -> MessageType:
time.sleep(10.1) time.sleep(10.1)
device_handler.client.set_filter(messages.TxAck, None) session.set_filter(messages.TxAck, None)
return msg return msg
with device_handler.client: with session, device_handler.client:
device_handler.client.set_filter(messages.TxAck, sleepy_filter) session.set_filter(messages.TxAck, sleepy_filter)
# confirm transaction # confirm transaction
# In all cases we set wait=False to avoid waiting for the screen and triggering # In all cases we set wait=False to avoid waiting for the screen and triggering
# the layout deadlock detection. In reality there is no deadlock but the # the layout deadlock detection. In reality there is no deadlock but the
@ -187,7 +192,7 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
# timeout is 3. In this test we don't need the result of the input event so # timeout is 3. In this test we don't need the result of the input event so
# waiting for it is not necessary. # waiting for it is not necessary.
if debug.layout_type is LayoutType.TT: if debug.layout_type is LayoutType.TT:
debug.click(buttons.OK, wait=False) debug.click(buttons.OK, hold_ms=1000, wait=False)
elif debug.layout_type is LayoutType.Mercury: elif debug.layout_type is LayoutType.Mercury:
debug.click(buttons.TAP_TO_CONFIRM, wait=False) debug.click(buttons.TAP_TO_CONFIRM, wait=False)
elif debug.layout_type is LayoutType.TR: elif debug.layout_type is LayoutType.TR:
@ -196,7 +201,6 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
signatures, tx = device_handler.result() signatures, tx = device_handler.result()
assert len(signatures) == 1 assert len(signatures) == 1
assert tx assert tx
assert device_handler.features().unlocked is False assert device_handler.features().unlocked is False
@ -206,7 +210,7 @@ def test_autolock_passphrase_keyboard(device_handler: "BackgroundDeviceHandler")
debug = device_handler.debuglink() debug = device_handler.debuglink()
# get address # get address
device_handler.run(common.get_test_address) # type: ignore device_handler.run_with_session(common.get_test_address) # type: ignore
assert "PassphraseKeyboard" in debug.read_layout().all_components() assert "PassphraseKeyboard" in debug.read_layout().all_components()
@ -248,7 +252,7 @@ def test_autolock_interrupts_passphrase(device_handler: "BackgroundDeviceHandler
debug = device_handler.debuglink() debug = device_handler.debuglink()
# get address # get address
device_handler.run(common.get_test_address) # type: ignore device_handler.run_with_session(common.get_test_address) # type: ignore
assert "PassphraseKeyboard" in debug.read_layout().all_components() assert "PassphraseKeyboard" in debug.read_layout().all_components()
@ -287,7 +291,7 @@ def test_dryrun_locks_at_number_of_words(device_handler: "BackgroundDeviceHandle
set_autolock_delay(device_handler, 10_000) set_autolock_delay(device_handler, 10_000)
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.recover, type=messages.RecoveryType.DryRun) device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
layout = unlock_dry_run(debug) layout = unlock_dry_run(debug)
assert TR.recovery__num_of_words in debug.read_layout().text_content() assert TR.recovery__num_of_words in debug.read_layout().text_content()
@ -319,7 +323,7 @@ def test_dryrun_locks_at_word_entry(device_handler: "BackgroundDeviceHandler"):
set_autolock_delay(device_handler, 10_000) set_autolock_delay(device_handler, 10_000)
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.recover, type=messages.RecoveryType.DryRun) device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
unlock_dry_run(debug) unlock_dry_run(debug)
@ -345,7 +349,7 @@ def test_dryrun_enter_word_slowly(device_handler: "BackgroundDeviceHandler"):
set_autolock_delay(device_handler, 10_000) set_autolock_delay(device_handler, 10_000)
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.recover, type=messages.RecoveryType.DryRun) device_handler.run_with_session(device.recover, type=messages.RecoveryType.DryRun)
unlock_dry_run(debug) unlock_dry_run(debug)
@ -405,7 +409,11 @@ def test_autolock_does_not_interrupt_preauthorized(
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run( # Prepare session to use later
session = Session(device_handler.client.get_session())
device_handler.run_with_provided_session(
session,
btc.authorize_coinjoin, btc.authorize_coinjoin,
coordinator="www.example.com", coordinator="www.example.com",
max_rounds=2, max_rounds=2,
@ -519,14 +527,15 @@ def test_autolock_does_not_interrupt_preauthorized(
def sleepy_filter(msg: MessageType) -> MessageType: def sleepy_filter(msg: MessageType) -> MessageType:
time.sleep(10.1) time.sleep(10.1)
device_handler.client.set_filter(messages.SignTx, None) session.set_filter(messages.SignTx, None)
return msg return msg
with device_handler.client: with session:
# Start DoPreauthorized flow when device is unlocked. Wait 10s before # Start DoPreauthorized flow when device is unlocked. Wait 10s before
# delivering SignTx, by that time autolock timer should have fired. # delivering SignTx, by that time autolock timer should have fired.
device_handler.client.set_filter(messages.SignTx, sleepy_filter) session.set_filter(messages.SignTx, sleepy_filter)
device_handler.run( device_handler.run_with_provided_session(
session,
btc.sign_tx, btc.sign_tx,
"Testnet", "Testnet",
inputs, inputs,

View File

@ -53,7 +53,9 @@ def test_backup_slip39_custom(
assert features.initialized is False assert features.initialized is False
device_handler.run( session = device_handler.client.get_management_session()
device_handler.run_with_provided_session(
session,
device.reset, device.reset,
strength=128, strength=128,
backup_type=messages.BackupType.Slip39_Basic, backup_type=messages.BackupType.Slip39_Basic,
@ -68,7 +70,7 @@ def test_backup_slip39_custom(
assert device_handler.result() == "Initialized" assert device_handler.result() == "Initialized"
device_handler.run( device_handler.run_with_session(
device.backup, device.backup,
group_threshold=group_threshold, group_threshold=group_threshold,
groups=[(share_threshold, share_count)], groups=[(share_threshold, share_count)],

View File

@ -65,20 +65,22 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"):
device_handler.run_with_session(common.get_test_address) device_handler.run_with_session(common.get_test_address)
assert "PinKeyboard" in debug.read_layout().all_components() assert "PinKeyboard" in debug.read_layout().all_components()
time.sleep(10)
debug.input("1234") debug.input("1234")
assert device_handler.result() assert device_handler.result()
session.refresh_features()
assert device_handler.features().unlocked is True assert device_handler.features().unlocked is True
# short touch # short touch
hold(short_duration) hold(short_duration)
time.sleep(0.5) # so that the homescreen appears again (hacky) time.sleep(0.5) # so that the homescreen appears again (hacky)
session.refresh_features()
assert device_handler.features().unlocked is True assert device_handler.features().unlocked is True
# lock # lock
hold(lock_duration) hold(lock_duration)
session.refresh_features()
assert device_handler.features().unlocked is False assert device_handler.features().unlocked is False
# unlock by touching # unlock by touching
@ -89,8 +91,10 @@ def test_hold_to_lock(device_handler: "BackgroundDeviceHandler"):
assert "PinKeyboard" in layout.all_components() assert "PinKeyboard" in layout.all_components()
debug.input("1234") debug.input("1234")
session.refresh_features()
assert device_handler.features().unlocked is True assert device_handler.features().unlocked is True
# lock # lock
hold(lock_duration) hold(lock_duration)
session.refresh_features()
assert device_handler.features().unlocked is False assert device_handler.features().unlocked is False

View File

@ -97,7 +97,7 @@ def prepare_passphrase_dialogue(
device_handler: "BackgroundDeviceHandler", address: Optional[str] = None device_handler: "BackgroundDeviceHandler", address: Optional[str] = None
) -> Generator["DebugLink", None, None]: ) -> Generator["DebugLink", None, None]:
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(get_test_address) # type: ignore device_handler.run_with_session(get_test_address) # type: ignore
# TODO # TODO
assert debug.read_layout().main_component() == "PassphraseKeyboard" assert debug.read_layout().main_component() == "PassphraseKeyboard"

View File

@ -91,7 +91,7 @@ def prepare_passphrase_dialogue(
device_handler: "BackgroundDeviceHandler", address: Optional[str] = None device_handler: "BackgroundDeviceHandler", address: Optional[str] = None
) -> Generator["DebugLink", None, None]: ) -> Generator["DebugLink", None, None]:
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(get_test_address) # type: ignore device_handler.run_with_session(get_test_address) # type: ignore
layout = debug.read_layout() layout = debug.read_layout()
assert "PassphraseKeyboard" in layout.all_components() assert "PassphraseKeyboard" in layout.all_components()
assert layout.passphrase() == "" assert layout.passphrase() == ""

View File

@ -69,7 +69,7 @@ def prepare_passphrase_dialogue(
device_handler: "BackgroundDeviceHandler", address: Optional[str] = None device_handler: "BackgroundDeviceHandler", address: Optional[str] = None
) -> Generator["DebugLink", None, None]: ) -> Generator["DebugLink", None, None]:
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(get_test_address) # type: ignore device_handler.run_with_session(get_test_address) # type: ignore
assert debug.read_layout().main_component() == "PassphraseKeyboard" assert debug.read_layout().main_component() == "PassphraseKeyboard"
# Resetting the category as it could have been changed by previous tests # Resetting the category as it could have been changed by previous tests

View File

@ -23,6 +23,7 @@ import pytest
from trezorlib import device, exceptions from trezorlib import device, exceptions
from trezorlib.debuglink import LayoutType from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from .. import buttons from .. import buttons
from .. import translations as TR from .. import translations as TR
@ -91,17 +92,19 @@ def prepare(
tap = False tap = False
Session(device_handler.client.get_management_session()).lock()
# Setup according to the wanted situation # Setup according to the wanted situation
if situation == Situation.PIN_INPUT: if situation == Situation.PIN_INPUT:
# Any action triggering the PIN dialogue # Any action triggering the PIN dialogue
device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
tap = True tap = True
if situation == Situation.PIN_INPUT_CANCEL: if situation == Situation.PIN_INPUT_CANCEL:
# Any action triggering the PIN dialogue # Any action triggering the PIN dialogue
device_handler.run(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore device_handler.run_with_session(device.apply_settings, auto_lock_delay_ms=300_000) # type: ignore
elif situation == Situation.PIN_SETUP: elif situation == Situation.PIN_SETUP:
# Set new PIN # Set new PIN
device_handler.run(device.change_pin) # type: ignore device_handler.run_with_session(device.change_pin) # type: ignore
assert ( assert (
TR.pin__turn_on in debug.read_layout().text_content() TR.pin__turn_on in debug.read_layout().text_content()
or TR.pin__info in debug.read_layout().text_content() or TR.pin__info in debug.read_layout().text_content()
@ -115,14 +118,14 @@ def prepare(
go_next(debug) go_next(debug)
elif situation == Situation.PIN_CHANGE: elif situation == Situation.PIN_CHANGE:
# Change PIN # Change PIN
device_handler.run(device.change_pin) # type: ignore device_handler.run_with_session(device.change_pin) # type: ignore
_input_see_confirm(debug, old_pin) _input_see_confirm(debug, old_pin)
assert TR.pin__change in debug.read_layout().text_content() assert TR.pin__change in debug.read_layout().text_content()
go_next(debug) go_next(debug)
_input_see_confirm(debug, old_pin) _input_see_confirm(debug, old_pin)
elif situation == Situation.WIPE_CODE_SETUP: elif situation == Situation.WIPE_CODE_SETUP:
# Set wipe code # Set wipe code
device_handler.run(device.change_wipe_code) # type: ignore device_handler.run_with_session(device.change_wipe_code) # type: ignore
if old_pin: if old_pin:
_input_see_confirm(debug, old_pin) _input_see_confirm(debug, old_pin)
assert TR.wipe_code__turn_on in debug.read_layout().text_content() assert TR.wipe_code__turn_on in debug.read_layout().text_content()

View File

@ -40,7 +40,7 @@ def prepare_recovery_and_evaluate(
features = device_handler.features() features = device_handler.features()
debug = device_handler.debuglink() debug = device_handler.debuglink()
assert features.initialized is False assert features.initialized is False
device_handler.run(device.recover, pin_protection=False) # type: ignore device_handler.run_with_session(device.recover, pin_protection=False) # type: ignore
yield debug yield debug

View File

@ -42,7 +42,7 @@ def test_repeated_backup(
assert features.initialized is False assert features.initialized is False
device_handler.run( device_handler.run_with_session(
device.reset, device.reset,
strength=128, strength=128,
backup_type=messages.BackupType.Slip39_Basic, backup_type=messages.BackupType.Slip39_Basic,
@ -94,7 +94,7 @@ def test_repeated_backup(
assert features.recovery_status == messages.RecoveryStatus.Nothing assert features.recovery_status == messages.RecoveryStatus.Nothing
# run recovery to unlock backup # run recovery to unlock backup
device_handler.run( device_handler.run_with_session(
device.recover, device.recover,
type=messages.RecoveryType.UnlockRepeatedBackup, type=messages.RecoveryType.UnlockRepeatedBackup,
) )
@ -161,7 +161,7 @@ def test_repeated_backup(
assert features.recovery_status == messages.RecoveryStatus.Nothing assert features.recovery_status == messages.RecoveryStatus.Nothing
# try to unlock backup again... # try to unlock backup again...
device_handler.run( device_handler.run_with_session(
device.recover, device.recover,
type=messages.RecoveryType.UnlockRepeatedBackup, type=messages.RecoveryType.UnlockRepeatedBackup,
) )

View File

@ -40,7 +40,7 @@ def test_reset_bip39(device_handler: "BackgroundDeviceHandler"):
assert features.initialized is False assert features.initialized is False
device_handler.run( device_handler.run_with_session(
device.reset, device.reset,
strength=128, strength=128,
backup_type=messages.BackupType.Bip39, backup_type=messages.BackupType.Bip39,

View File

@ -52,7 +52,7 @@ def test_reset_slip39_advanced(
assert features.initialized is False assert features.initialized is False
device_handler.run( device_handler.run_with_session(
device.reset, device.reset,
backup_type=messages.BackupType.Slip39_Advanced, backup_type=messages.BackupType.Slip39_Advanced,
pin_protection=False, pin_protection=False,

View File

@ -48,7 +48,7 @@ def test_reset_slip39_basic(
assert features.initialized is False assert features.initialized is False
device_handler.run( device_handler.run_with_session(
device.reset, device.reset,
strength=128, strength=128,
backup_type=messages.BackupType.Slip39_Basic, backup_type=messages.BackupType.Slip39_Basic,

View File

@ -36,7 +36,7 @@ pytestmark = [
def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"): def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial) device_handler.run_with_session(device.show_device_tutorial)
layout = debug.read_layout() layout = debug.read_layout()
assert layout.title() == TR.tutorial__welcome_safe5 assert layout.title() == TR.tutorial__welcome_safe5
@ -57,7 +57,7 @@ def test_tutorial_ignore_menu(device_handler: "BackgroundDeviceHandler"):
def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial) device_handler.run_with_session(device.show_device_tutorial)
layout = debug.read_layout() layout = debug.read_layout()
assert layout.title() == TR.tutorial__welcome_safe5 assert layout.title() == TR.tutorial__welcome_safe5
@ -84,7 +84,7 @@ def test_tutorial_menu_open_close(device_handler: "BackgroundDeviceHandler"):
def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial) device_handler.run_with_session(device.show_device_tutorial)
layout = debug.read_layout() layout = debug.read_layout()
assert layout.title() == TR.tutorial__welcome_safe5 assert layout.title() == TR.tutorial__welcome_safe5
@ -108,7 +108,7 @@ def test_tutorial_menu_exit(device_handler: "BackgroundDeviceHandler"):
def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial) device_handler.run_with_session(device.show_device_tutorial)
layout = debug.read_layout() layout = debug.read_layout()
assert layout.title() == TR.tutorial__welcome_safe5 assert layout.title() == TR.tutorial__welcome_safe5
@ -139,7 +139,7 @@ def test_tutorial_menu_repeat(device_handler: "BackgroundDeviceHandler"):
def test_tutorial_menu_funfact(device_handler: "BackgroundDeviceHandler"): def test_tutorial_menu_funfact(device_handler: "BackgroundDeviceHandler"):
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial) device_handler.run_with_session(device.show_device_tutorial)
layout = debug.read_layout() layout = debug.read_layout()
assert layout.title() == TR.tutorial__welcome_safe5 assert layout.title() == TR.tutorial__welcome_safe5

View File

@ -39,7 +39,7 @@ def prepare_tutorial_and_cancel_after_it(
device_handler: "BackgroundDeviceHandler", cancelled: bool = False device_handler: "BackgroundDeviceHandler", cancelled: bool = False
) -> Generator["DebugLink", None, None]: ) -> Generator["DebugLink", None, None]:
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run(device.show_device_tutorial) device_handler.run_with_session(device.show_device_tutorial)
yield debug yield debug

View File

@ -33,9 +33,6 @@ class NullUI:
else: else:
raise NotImplementedError("NullUI should not be used with T1") raise NotImplementedError("NullUI should not be used with T1")
def debug_callback_button(self, session: Any, msg: Any) -> Any:
raise RuntimeError("unexpected call to a fake debuglink")
class BackgroundDeviceHandler: class BackgroundDeviceHandler:
_pool = ThreadPoolExecutor() _pool = ThreadPoolExecutor()
@ -78,6 +75,20 @@ class BackgroundDeviceHandler:
session = self.client.get_session() session = self.client.get_session()
self.task = self._pool.submit(function, session, *args, **kwargs) self.task = self._pool.submit(function, session, *args, **kwargs)
def run_with_provided_session(
self, session, function: Callable[..., Any], *args: Any, **kwargs: Any
) -> None:
"""Runs some function that interacts with a device.
Makes sure the UI is updated before returning.
"""
if self.task is not None:
raise RuntimeError("Wait for previous task first")
# wait for the first UI change triggered by the task running in the background
with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, session, *args, **kwargs)
def kill_task(self) -> None: def kill_task(self) -> None:
if self.task is not None: if self.task is not None:
# Force close the client, which should raise an exception in a client # Force close the client, which should raise an exception in a client
@ -108,6 +119,7 @@ class BackgroundDeviceHandler:
def features(self) -> "Features": def features(self) -> "Features":
if self.task is not None: if self.task is not None:
raise RuntimeError("Cannot query features while task is running") raise RuntimeError("Cannot query features while task is running")
self.client.refresh_features()
return self.client.features return self.client.features
def debuglink(self) -> "DebugLink": def debuglink(self) -> "DebugLink":

View File

@ -794,7 +794,6 @@ def test_get_address(session: Session):
def test_multisession_authorization(client: Client): def test_multisession_authorization(client: Client):
# Authorize CoinJoin with www.example1.com in session 1. # Authorize CoinJoin with www.example1.com in session 1.
session1 = client.get_session() session1 = client.get_session()
session_id_1 = session1.id
btc.authorize_coinjoin( btc.authorize_coinjoin(
session1, session1,
@ -807,7 +806,6 @@ def test_multisession_authorization(client: Client):
script_type=messages.InputScriptType.SPENDTAPROOT, script_type=messages.InputScriptType.SPENDTAPROOT,
) )
session2 = client.get_session() session2 = client.get_session()
session_id_2 = session2.id
# Open a second session. # Open a second session.
# session_id1 = session.session_id # session_id1 = session.session_id
# TODO client.init_device(new_session=True) # TODO client.init_device(new_session=True)

View File

@ -18,8 +18,8 @@ import pytest
from trezorlib import cardano, messages, models from trezorlib import cardano, messages, models
from trezorlib.btc import get_public_node from trezorlib.btc import get_public_node
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
@ -159,7 +159,7 @@ def test_session_recycling(client: Client):
] ]
) )
client.use_passphrase("TREZOR") client.use_passphrase("TREZOR")
address = get_test_address(session) # address = get_test_address(session)
# create and close 100 sessions - more than the session limit # create and close 100 sessions - more than the session limit
for _ in range(100): for _ in range(100):

View File

@ -20,8 +20,8 @@ import pytest
from trezorlib import device, exceptions, messages from trezorlib import device, exceptions, messages
from trezorlib.debuglink import LayoutType from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.messages import FailureType, SafetyCheckLevel from trezorlib.messages import FailureType, SafetyCheckLevel
from trezorlib.tools import parse_path from trezorlib.tools import parse_path

View File

@ -20,8 +20,8 @@ import pytest
from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib import btc, device, mapping, messages, models, protobuf
from trezorlib._internal.emulator import Emulator from trezorlib._internal.emulator import Emulator
from trezorlib.tools import parse_path
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
from ..emulators import EmulatorWrapper from ..emulators import EmulatorWrapper
from . import for_all from . import for_all