From a129a05afd28cec0247d3eeead51576e189a7f7d Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Thu, 6 Mar 2025 00:35:51 +0100 Subject: [PATCH] refactor(tests): move set_input_flow to SessionDebugWrapper context manager [no changelog] --- python/src/trezorlib/debuglink.py | 134 ++++++------------ tests/burn_tests/burntest_t2.py | 8 +- tests/click_tests/test_autolock.py | 2 +- tests/conftest.py | 2 +- tests/device_handler.py | 4 + .../device_tests/binance/test_get_address.py | 6 +- .../binance/test_get_public_key.py | 6 +- .../bitcoin/test_authorize_coinjoin.py | 4 +- .../device_tests/bitcoin/test_descriptors.py | 14 +- tests/device_tests/bitcoin/test_getaddress.py | 36 ++--- .../bitcoin/test_getaddress_segwit.py | 6 +- .../bitcoin/test_getaddress_show.py | 38 ++--- .../device_tests/bitcoin/test_getpublickey.py | 10 +- tests/device_tests/bitcoin/test_multisig.py | 6 +- .../bitcoin/test_multisig_change.py | 12 +- .../bitcoin/test_nonstandard_paths.py | 30 ++-- .../device_tests/bitcoin/test_signmessage.py | 24 ++-- tests/device_tests/bitcoin/test_signtx.py | 42 +++--- .../bitcoin/test_signtx_invalid_path.py | 12 +- .../bitcoin/test_signtx_payreq.py | 6 +- .../bitcoin/test_signtx_prevhash.py | 12 +- .../bitcoin/test_signtx_segwit_native.py | 24 ++-- .../bitcoin/test_verifymessage.py | 6 +- .../cardano/test_address_public_key.py | 8 +- tests/device_tests/cardano/test_sign_tx.py | 8 +- tests/device_tests/eos/test_get_public_key.py | 6 +- .../device_tests/ethereum/test_definitions.py | 12 +- .../device_tests/ethereum/test_getaddress.py | 6 +- .../ethereum/test_sign_typed_data.py | 16 +-- .../ethereum/test_sign_verify_message.py | 12 +- tests/device_tests/ethereum/test_signtx.py | 24 ++-- .../misc/test_msg_enablelabeling.py | 4 +- tests/device_tests/monero/test_getaddress.py | 6 +- .../test_recovery_bip39_dryrun.py | 14 +- .../reset_recovery/test_recovery_bip39_t2.py | 12 +- .../test_recovery_slip39_advanced.py | 28 ++-- .../test_recovery_slip39_advanced_dryrun.py | 12 +- .../test_recovery_slip39_basic.py | 60 ++++---- .../test_recovery_slip39_basic_dryrun.py | 12 +- .../reset_recovery/test_reset_backup.py | 24 ++-- .../reset_recovery/test_reset_bip39_t2.py | 28 ++-- .../test_reset_recovery_bip39.py | 14 +- .../test_reset_recovery_slip39_advanced.py | 12 +- .../test_reset_recovery_slip39_basic.py | 12 +- .../test_reset_slip39_advanced.py | 6 +- .../reset_recovery/test_reset_slip39_basic.py | 12 +- tests/device_tests/ripple/test_get_address.py | 6 +- tests/device_tests/solana/test_sign_tx.py | 6 +- tests/device_tests/stellar/test_stellar.py | 6 +- tests/device_tests/test_autolock.py | 16 +-- tests/device_tests/test_busy_state.py | 4 +- tests/device_tests/test_cancel.py | 4 +- tests/device_tests/test_debuglink.py | 6 +- tests/device_tests/test_language.py | 10 +- tests/device_tests/test_msg_applysettings.py | 16 +-- tests/device_tests/test_msg_backup_device.py | 31 ++-- .../test_msg_change_wipe_code_t1.py | 32 ++--- .../test_msg_change_wipe_code_t2.py | 46 +++--- tests/device_tests/test_msg_changepin_t1.py | 28 ++-- tests/device_tests/test_msg_changepin_t2.py | 42 +++--- tests/device_tests/test_msg_wipedevice.py | 26 ++-- tests/device_tests/test_pin.py | 18 +-- tests/device_tests/test_protection_levels.py | 62 ++++---- tests/device_tests/test_repeated_backup.py | 54 +++---- tests/device_tests/test_sdcard.py | 22 +-- tests/device_tests/test_session.py | 10 +- .../test_session_id_and_passphrase.py | 14 +- tests/device_tests/tezos/test_getaddress.py | 6 +- .../webauthn/test_msg_webauthn.py | 6 +- tests/input_flows.py | 12 +- tests/persistence_tests/test_wipe_code.py | 16 ++- tests/upgrade_tests/test_firmware_upgrades.py | 9 +- 72 files changed, 632 insertions(+), 668 deletions(-) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 40df48864e..99835efbdc 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -796,10 +796,10 @@ class DebugUI: def __init__(self, debuglink: DebugLink) -> None: self.debuglink = debuglink + self.pins: t.Iterator[str] | None = None self.clear() def clear(self) -> None: - self.pins: t.Iterator[str] | None = None self.passphrase = None self.input_flow: t.Union[ t.Generator[None, messages.ButtonRequest, None], object, None @@ -971,12 +971,10 @@ class SessionDebugWrapper(Session): return self.client.protocol_version def _write(self, msg: t.Any) -> None: - print("writing message:", msg.__class__.__name__) self._session._write(self._filter_message(msg)) def _read(self) -> t.Any: resp = self._filter_message(self._session._read()) - print("reading message:", resp.__class__.__name__) if self.actual_responses is not None: self.actual_responses.append(resp) return resp @@ -1074,6 +1072,7 @@ class SessionDebugWrapper(Session): Clears all debugging state that might have been modified by a testcase. """ + self.client.ui.clear() # type: ignore [Cannot access attribute] self.in_with_statement = False self.expected_responses: list[MessageFilter] | None = None self.actual_responses: list[protobuf.MessageType] | None = None @@ -1110,7 +1109,6 @@ class SessionDebugWrapper(Session): # If no other exception was raised, evaluate missed responses # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - elif isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. @@ -1170,6 +1168,45 @@ class SessionDebugWrapper(Session): output.append("") return output + def set_input_flow( + self, + input_flow: InputFlowType | t.Callable[[], InputFlowType], + ) -> None: + """Configure a sequence of input events for the current with-block. + + The `input_flow` must be a generator function. A `yield` statement in the + input flow function waits for a ButtonRequest from the device, and returns + its code. + + Example usage: + + >>> def input_flow(): + >>> # wait for first button prompt + >>> code = yield + >>> assert code == ButtonRequestType.Other + >>> # press No + >>> client.debug.press_no() + >>> + >>> # wait for second button prompt + >>> yield + >>> # press Yes + >>> client.debug.press_yes() + >>> + >>> with session: + >>> session.set_input_flow(input_flow) + >>> some_call(session) + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + if callable(input_flow): + input_flow = input_flow() + if not hasattr(input_flow, "send"): + raise RuntimeError("input_flow should be a generator function") + self.client.ui.input_flow = input_flow # type: ignore [Cannot access attribute] + + next(input_flow) # start the generator + class TrezorClientDebugLink(TrezorClient): # This class implements automatic responses @@ -1211,7 +1248,6 @@ class TrezorClientDebugLink(TrezorClient): self.transport = transport self.ui: DebugUI = DebugUI(self.debug) - self.reset_debug_features() self._seedless_session = self.get_seedless_session(new_session=True) self.sync_responses() @@ -1236,15 +1272,6 @@ class TrezorClientDebugLink(TrezorClient): new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter return new_client - def reset_debug_features(self) -> None: - """ - Prepare the debugging client for a new testcase. - - Clears all debugging state that might have been modified by a testcase. - """ - self.ui: DebugUI = DebugUI(self.debug) - self.in_with_statement = False - def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any: __tracebackhide__ = True # for pytest # pylint: disable=W0612 # do this raw - send ButtonAck first, notify UI later @@ -1373,43 +1400,6 @@ class TrezorClientDebugLink(TrezorClient): else: return SessionDebugWrapper(super().resume_session(session)) - def set_input_flow( - self, input_flow: InputFlowType | t.Callable[[], InputFlowType] - ) -> None: - """Configure a sequence of input events for the current with-block. - - The `input_flow` must be a generator function. A `yield` statement in the - input flow function waits for a ButtonRequest from the device, and returns - its code. - - Example usage: - - >>> def input_flow(): - >>> # wait for first button prompt - >>> code = yield - >>> assert code == ButtonRequestType.Other - >>> # press No - >>> client.debug.press_no() - >>> - >>> # wait for second button prompt - >>> yield - >>> # press Yes - >>> client.debug.press_yes() - >>> - >>> with client: - >>> client.set_input_flow(input_flow) - >>> some_call(client) - """ - if not self.in_with_statement: - raise RuntimeError("Must be called inside 'with' statement") - - if callable(input_flow): - input_flow = input_flow() - if not hasattr(input_flow, "send"): - raise RuntimeError("input_flow should be a generator function") - self.ui.input_flow = input_flow - next(input_flow) # start the generator - def watch_layout(self, watch: bool = True) -> None: """Enable or disable watching layout changes. @@ -1423,29 +1413,6 @@ class TrezorClientDebugLink(TrezorClient): # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug self.debug.watch_layout(watch) - def __enter__(self) -> "TrezorClientDebugLink": - # For usage in with/expected_responses - if self.in_with_statement: - raise RuntimeError("Do not nest!") - self.in_with_statement = True - return self - - def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - # grab a copy of the inputflow generator to raise an exception through it - if isinstance(self.ui, DebugUI): - input_flow = self.ui.input_flow - else: - input_flow = None - - self.reset_debug_features() - - if exc_type is not None and isinstance(input_flow, t.Generator): - # Propagate the exception through the input flow, so that we see in - # traceback where it is stuck. - input_flow.throw(exc_type, value, traceback) - def use_pin_sequence(self, pins: t.Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. @@ -1457,25 +1424,6 @@ class TrezorClientDebugLink(TrezorClient): Only applies to T1, where device prompts the host for mnemonic words.""" self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - @staticmethod - def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: - start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) - stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) - output: list[str] = [] - output.append("Expected responses:") - if start_at > 0: - output.append(f" (...{start_at} previous responses omitted)") - for i in range(start_at, stop_at): - exp = expected[i] - prefix = " " if i != current else ">>> " - output.append(textwrap.indent(exp.to_string(), prefix)) - if stop_at < len(expected): - omitted = len(expected) - stop_at - output.append(f" (...{omitted} following responses omitted)") - - output.append("") - return output - def sync_responses(self) -> None: """Synchronize Trezor device receiving with caller. diff --git a/tests/burn_tests/burntest_t2.py b/tests/burn_tests/burntest_t2.py index 5f1048254c..98f47424f6 100755 --- a/tests/burn_tests/burntest_t2.py +++ b/tests/burn_tests/burntest_t2.py @@ -56,7 +56,7 @@ def pin_input_flow(client: Client, old_pin: str, new_pin: str): if __name__ == "__main__": wirelink = get_device() client = Client(wirelink) - client.open() + session = client.get_seedless_session() i = 0 @@ -76,10 +76,12 @@ if __name__ == "__main__": # change PIN new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10))) - client.set_input_flow(pin_input_flow(client, last_pin, new_pin)) + session.set_input_flow(pin_input_flow(client, last_pin, new_pin)) device.change_pin(client) - client.set_input_flow(None) + session.set_input_flow(None) last_pin = new_pin print(f"iteration {i}") i = i + 1 + + wirelink.close() diff --git a/tests/click_tests/test_autolock.py b/tests/click_tests/test_autolock.py index 98a5bfd87d..8cf8b1d1de 100644 --- a/tests/click_tests/test_autolock.py +++ b/tests/click_tests/test_autolock.py @@ -198,7 +198,7 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa session.set_filter(messages.TxAck, None) return msg - with session, device_handler.client: + with session: session.set_filter(messages.TxAck, sleepy_filter) # confirm transaction if debug.layout_type is LayoutType.Bolt: diff --git a/tests/conftest.py b/tests/conftest.py index 94cf24ded9..7cce0be359 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -288,7 +288,7 @@ def _client_unlocked( test_ui = request.config.getoption("ui") - _raw_client.reset_debug_features() + # _raw_client.reset_debug_features() if isinstance(_raw_client.protocol, ProtocolV1Channel): try: _raw_client.sync_responses() diff --git a/tests/device_handler.py b/tests/device_handler.py index cf8a8e06fd..9edf0b560d 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -22,6 +22,10 @@ udp.SOCKET_TIMEOUT = 0.1 class NullUI: + @staticmethod + def clear(*args, **kwargs): + pass + @staticmethod def button_request(code): pass diff --git a/tests/device_tests/binance/test_get_address.py b/tests/device_tests/binance/test_get_address.py index 6b5a024767..d242a9bb53 100644 --- a/tests/device_tests/binance/test_get_address.py +++ b/tests/device_tests/binance/test_get_address.py @@ -51,9 +51,9 @@ def test_binance_get_address_chunkify_details( ): # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address = get_address( session, parse_path(path), show_display=True, chunkify=True ) diff --git a/tests/device_tests/binance/test_get_public_key.py b/tests/device_tests/binance/test_get_public_key.py index f65baa5dd8..a4ed06a6c1 100644 --- a/tests/device_tests/binance/test_get_public_key.py +++ b/tests/device_tests/binance/test_get_public_key.py @@ -32,9 +32,9 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0") mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin" ) def test_binance_get_public_key(session: Session): - with session.client as client: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) assert ( sig.hex() diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 0da918e417..f6454726f7 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -65,8 +65,8 @@ def test_sign_tx(session: Session, chunkify: bool): assert session.features.unlocked is False commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") - with session.client as client: - client.use_pin_sequence([PIN]) + with session: + session.client.use_pin_sequence([PIN]) btc.authorize_coinjoin( session, coordinator="www.example.com", diff --git a/tests/device_tests/bitcoin/test_descriptors.py b/tests/device_tests/bitcoin/test_descriptors.py index 7a077b2052..ab0d536cf2 100644 --- a/tests/device_tests/bitcoin/test_descriptors.py +++ b/tests/device_tests/bitcoin/test_descriptors.py @@ -168,9 +168,9 @@ def _address_n(purpose, coin, account, script_type): def test_descriptors( session: Session, coin, account, purpose, script_type, descriptors ): - with session.client as client: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) address_n = _address_n(purpose, coin, account, script_type) res = btc.get_public_node( @@ -191,10 +191,10 @@ def test_descriptors( def test_descriptors_trezorlib( session: Session, coin, account, purpose, script_type, descriptors ): - with session.client as client: - if client.model != models.T1B1: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + if session.client.model != models.T1B1: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) res = btc_cli._get_descriptor( session, coin, account, purpose, script_type, show_display=True ) diff --git a/tests/device_tests/bitcoin/test_getaddress.py b/tests/device_tests/bitcoin/test_getaddress.py index 3c8a2fbc9d..41c8712f04 100644 --- a/tests/device_tests/bitcoin/test_getaddress.py +++ b/tests/device_tests/bitcoin/test_getaddress.py @@ -270,10 +270,10 @@ def test_multisig(session: Session): xpubs.append(node.xpub) for nr in range(1, 4): - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, @@ -321,10 +321,10 @@ def test_multisig_missing(session: Session, show_display): ) for multisig in (multisig1, multisig2): - with session.client as client, pytest.raises(TrezorFailure): + with pytest.raises(TrezorFailure), session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.get_address( session, "Bitcoin", @@ -345,10 +345,10 @@ def test_bch_multisig(session: Session): xpubs.append(node.xpub) for nr in range(1, 4): - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, @@ -406,7 +406,7 @@ def test_unknown_path(session: Session): # disable safety checks device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with session, session.client as client: + with session: session.set_expected_responses( [ messages.ButtonRequest( @@ -417,8 +417,8 @@ def test_unknown_path(session: Session): ] ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) # try again with a warning btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) @@ -455,10 +455,10 @@ def test_multisig_different_paths(session: Session): with pytest.raises( Exception, match="Using different paths for different xpubs is not allowed" ): - with session.client as client, session: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.get_address( session, "Bitcoin", @@ -469,10 +469,10 @@ def test_multisig_different_paths(session: Session): ) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.get_address( session, "Bitcoin", diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit.py b/tests/device_tests/bitcoin/test_getaddress_segwit.py index b1e3affac7..d81649e4dd 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit.py @@ -74,10 +74,10 @@ def test_show_segwit(session: Session): @pytest.mark.altcoin def test_show_segwit_altcoin(session: Session): - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, diff --git a/tests/device_tests/bitcoin/test_getaddress_show.py b/tests/device_tests/bitcoin/test_getaddress_show.py index 464c9cc70e..bd04712e99 100644 --- a/tests/device_tests/bitcoin/test_getaddress_show.py +++ b/tests/device_tests/bitcoin/test_getaddress_show.py @@ -63,9 +63,9 @@ def test_show_t1( yield session.client.debug.press_yes() - with session.client as client: + with session: # This is the only place where even T1 is using input flow - client.set_input_flow(input_flow_t1) + session.set_input_flow(input_flow_t1) assert ( btc.get_address( session, @@ -88,9 +88,9 @@ def test_show_tt( script_type: messages.InputScriptType, address: str, ): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, @@ -109,9 +109,9 @@ def test_show_tt( def test_show_cancel( session: Session, path: str, script_type: messages.InputScriptType, address: str ): - with session.client as client, pytest.raises(Cancelled): - IF = InputFlowShowAddressQRCodeCancel(client) - client.set_input_flow(IF.get()) + with session, pytest.raises(Cancelled): + IF = InputFlowShowAddressQRCodeCancel(session.client) + session.set_input_flow(IF.get()) btc.get_address( session, "Bitcoin", @@ -157,10 +157,10 @@ def test_show_multisig_3(session: Session): for multisig in (multisig1, multisig2): for i in [1, 2, 3]: - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, @@ -273,11 +273,11 @@ def test_show_multisig_xpubs( ) for i in range(3): - with session, session.client as client: - IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) - client.set_input_flow(IF.get()) - client.debug.synchronize_at("Homescreen") - client.watch_layout() + with session: + IF = InputFlowShowMultisigXPUBs(session.client, address, xpubs, i) + session.set_input_flow(IF.get()) + session.client.debug.synchronize_at("Homescreen") + session.client.watch_layout() btc.get_address( session, "Bitcoin", @@ -314,10 +314,10 @@ def test_show_multisig_15(session: Session): for multisig in [multisig1, multisig2]: for i in range(15): - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) assert ( btc.get_address( session, diff --git a/tests/device_tests/bitcoin/test_getpublickey.py b/tests/device_tests/bitcoin/test_getpublickey.py index e013e6f71c..8009e35ac7 100644 --- a/tests/device_tests/bitcoin/test_getpublickey.py +++ b/tests/device_tests/bitcoin/test_getpublickey.py @@ -119,9 +119,9 @@ def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub): @pytest.mark.models("core") @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): - with session.client as client: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub @@ -158,14 +158,14 @@ def test_get_public_node_show_legacy( client.debug.press_yes() # finish the flow yield - with client: + with session: # test XPUB display flow (without showing QR code) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub # test XPUB QR code display using the input flow above - client.set_input_flow(input_flow) + session.set_input_flow(input_flow) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub assert bip32.serialize(res.node, xpub_magic) == xpub diff --git a/tests/device_tests/bitcoin/test_multisig.py b/tests/device_tests/bitcoin/test_multisig.py index 5888409d86..4f5b87f044 100644 --- a/tests/device_tests/bitcoin/test_multisig.py +++ b/tests/device_tests/bitcoin/test_multisig.py @@ -475,10 +475,10 @@ def test_attack_change_input(session: Session): ) # Transaction can be signed without the attack processor - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, "Testnet", diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index efc4f42d56..cef1479e3e 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -288,7 +288,7 @@ def test_external_internal(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session, session.client as client: + with session: session.set_expected_responses( _responses( session, @@ -299,8 +299,8 @@ def test_external_internal(session: Session): ) ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, "Bitcoin", @@ -324,7 +324,7 @@ def test_internal_external(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session, session.client as client: + with session: session.set_expected_responses( _responses( session, @@ -335,8 +335,8 @@ def test_internal_external(session: Session): ) ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, "Bitcoin", diff --git a/tests/device_tests/bitcoin/test_nonstandard_paths.py b/tests/device_tests/bitcoin/test_nonstandard_paths.py index 77d57aa951..d4e5ac1350 100644 --- a/tests/device_tests/bitcoin/test_nonstandard_paths.py +++ b/tests/device_tests/bitcoin/test_nonstandard_paths.py @@ -113,10 +113,10 @@ def test_getaddress( script_types: list[messages.InputScriptType], ): for script_type in script_types: - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) res = btc.get_address( session, "Bitcoin", @@ -134,10 +134,10 @@ def test_signmessage( session: Session, path: str, script_types: list[messages.InputScriptType] ): for script_type in script_types: - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) sig = btc.sign_message( session, @@ -175,10 +175,10 @@ def test_signtx( script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) @@ -202,10 +202,10 @@ def test_getaddress_multisig( ] multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) address = btc.get_address( session, "Bitcoin", @@ -261,10 +261,10 @@ def test_signtx_multisig(session: Session, paths: list[str], address_index: list script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) sig, _ = btc.sign_tx( session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} ) diff --git a/tests/device_tests/bitcoin/test_signmessage.py b/tests/device_tests/bitcoin/test_signmessage.py index bf9ec4e326..52b9fb7bbf 100644 --- a/tests/device_tests/bitcoin/test_signmessage.py +++ b/tests/device_tests/bitcoin/test_signmessage.py @@ -327,9 +327,9 @@ def test_signmessage_long( message: str, signature: str, ): - with session.client as client: - IF = InputFlowSignVerifyMessageLong(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignVerifyMessageLong(session.client) + session.set_input_flow(IF.get()) sig = btc.sign_message( session, coin_name=coin_name, @@ -356,9 +356,9 @@ def test_signmessage_info( message: str, signature: str, ): - with session.client as client, pytest.raises(Cancelled): - IF = InputFlowSignMessageInfo(client) - client.set_input_flow(IF.get()) + with session, pytest.raises(Cancelled): + IF = InputFlowSignMessageInfo(session.client) + session.set_input_flow(IF.get()) sig = btc.sign_message( session, coin_name=coin_name, @@ -390,13 +390,13 @@ MESSAGE_LENGTHS = ( @pytest.mark.models("core") @pytest.mark.parametrize("message,is_long", MESSAGE_LENGTHS) def test_signmessage_pagination(session: Session, message: str, is_long: bool): - with session.client as client: + with session: IF = ( InputFlowSignVerifyMessageLong if is_long else InputFlowSignMessagePagination - )(client) - client.set_input_flow(IF.get()) + )(session.client) + session.set_input_flow(IF.get()) btc.sign_message( session, coin_name="Bitcoin", @@ -438,7 +438,7 @@ def test_signmessage_pagination_trailing_newline(session: Session): def test_signmessage_path_warning(session: Session): message = "This is an example of a signed message." - with session, session.client as client: + with session: session.set_expected_responses( [ # expect a path warning @@ -451,8 +451,8 @@ def test_signmessage_path_warning(session: Session): ] ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_message( session, coin_name="Bitcoin", diff --git a/tests/device_tests/bitcoin/test_signtx.py b/tests/device_tests/bitcoin/test_signtx.py index 216e928926..122a1cdee5 100644 --- a/tests/device_tests/bitcoin/test_signtx.py +++ b/tests/device_tests/bitcoin/test_signtx.py @@ -664,9 +664,9 @@ def test_fee_high_hardfail(session: Session): device.apply_settings( session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with session.client as client: - IF = InputFlowSignTxHighFee(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignTxHighFee(session.client) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET @@ -1467,9 +1467,9 @@ def test_lock_time_blockheight(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: - IF = InputFlowLockTimeBlockHeight(client, "499999999") - client.set_input_flow(IF.get()) + with session: + IF = InputFlowLockTimeBlockHeight(session.client, "499999999") + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1506,9 +1506,9 @@ def test_lock_time_datetime(session: Session, lock_time_str: str): lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc) lock_time_timestamp = int(lock_time_utc.timestamp()) - with session.client as client: - IF = InputFlowLockTimeDatetime(client, lock_time_str) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowLockTimeDatetime(session.client, lock_time_str) + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1538,9 +1538,9 @@ def test_information(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: - IF = InputFlowSignTxInformation(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignTxInformation(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1573,9 +1573,9 @@ def test_information_mixed(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: - IF = InputFlowSignTxInformationMixed(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignTxInformationMixed(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1604,9 +1604,9 @@ def test_information_cancel(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client, pytest.raises(Cancelled): - IF = InputFlowSignTxInformationCancel(client) - client.set_input_flow(IF.get()) + with session, pytest.raises(Cancelled): + IF = InputFlowSignTxInformationCancel(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, @@ -1654,9 +1654,9 @@ def test_information_replacement(session: Session): orig_index=0, ) - with session.client as client: - IF = InputFlowSignTxInformationReplacement(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignTxInformationReplacement(session.client) + session.set_input_flow(IF.get()) btc.sign_tx( session, diff --git a/tests/device_tests/bitcoin/test_signtx_invalid_path.py b/tests/device_tests/bitcoin/test_signtx_invalid_path.py index 27f0599de9..f41702047a 100644 --- a/tests/device_tests/bitcoin/test_signtx_invalid_path.py +++ b/tests/device_tests/bitcoin/test_signtx_invalid_path.py @@ -80,10 +80,10 @@ def test_invalid_path_prompt(session: Session): session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) @@ -106,10 +106,10 @@ def test_invalid_path_pass_forkid(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session.client as client: + with session: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) diff --git a/tests/device_tests/bitcoin/test_signtx_payreq.py b/tests/device_tests/bitcoin/test_signtx_payreq.py index 32c90d05e0..3f900bb05e 100644 --- a/tests/device_tests/bitcoin/test_signtx_payreq.py +++ b/tests/device_tests/bitcoin/test_signtx_payreq.py @@ -203,9 +203,9 @@ def test_payment_request_details(session: Session): ) ] - with session.client as client: - IF = InputFlowPaymentRequestDetails(client, outputs) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowPaymentRequestDetails(session.client, outputs) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, diff --git a/tests/device_tests/bitcoin/test_signtx_prevhash.py b/tests/device_tests/bitcoin/test_signtx_prevhash.py index a2f96c04ed..7ae24249ee 100644 --- a/tests/device_tests/bitcoin/test_signtx_prevhash.py +++ b/tests/device_tests/bitcoin/test_signtx_prevhash.py @@ -130,11 +130,11 @@ def test_invalid_prev_hash_attack(session: Session, prev_hash): msg.tx.inputs[0].prev_hash = prev_hash return msg - with session, session.client as client, pytest.raises(TrezorFailure) as e: + with session, pytest.raises(TrezorFailure) as e: session.set_filter(messages.TxAck, attack_filter) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) # check that injection was performed @@ -168,9 +168,9 @@ def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash): tx_hash = hash_tx(serialize_tx(prev_tx)) inp0.prev_hash = tx_hash - with session, session.client as client, pytest.raises(TrezorFailure) as e: + with session, pytest.raises(TrezorFailure) as e: if session.model is not models.T1B1: - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) _check_error_message(prev_hash, session.model, e.value.message) diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index 920b0bf48b..9c3c3f972b 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -611,11 +611,11 @@ def test_send_multisig_3_change(session: Session): request_finished(), ] - with session, session.client as client: + with session: session.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -626,11 +626,11 @@ def test_send_multisig_3_change(session: Session): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with session, session.client as client: + with session: session.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -703,11 +703,11 @@ def test_send_multisig_4_change(session: Session): request_finished(), ] - with session, session.client as client: + with session: session.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -718,11 +718,11 @@ def test_send_multisig_4_change(session: Session): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with session, session.client as client: + with session: session.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) diff --git a/tests/device_tests/bitcoin/test_verifymessage.py b/tests/device_tests/bitcoin/test_verifymessage.py index 36b7cc31f0..e02833ed21 100644 --- a/tests/device_tests/bitcoin/test_verifymessage.py +++ b/tests/device_tests/bitcoin/test_verifymessage.py @@ -40,9 +40,9 @@ def test_message_long_legacy(session: Session): @pytest.mark.models("core") def test_message_long_core(session: Session): - with session.client as client: - IF = InputFlowSignVerifyMessageLong(client, verify=True) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignVerifyMessageLong(session.client, verify=True) + session.set_input_flow(IF.get()) ret = btc.verify_message( session, "Bitcoin", diff --git a/tests/device_tests/cardano/test_address_public_key.py b/tests/device_tests/cardano/test_address_public_key.py index d8ec9288eb..14e1d2d4f2 100644 --- a/tests/device_tests/cardano/test_address_public_key.py +++ b/tests/device_tests/cardano/test_address_public_key.py @@ -95,9 +95,11 @@ def test_cardano_get_address(session: Session, chunkify: bool, parameters, resul "cardano/get_public_key.derivations.json", ) def test_cardano_get_public_key(session: Session, parameters, result): - with session, session.client as client: - IF = InputFlowShowXpubQRCode(client, passphrase=bool(session.passphrase)) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode( + session.client, passphrase=bool(session.passphrase) + ) + session.set_input_flow(IF.get()) # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ diff --git a/tests/device_tests/cardano/test_sign_tx.py b/tests/device_tests/cardano/test_sign_tx.py index 362a1793ce..ca4af67187 100644 --- a/tests/device_tests/cardano/test_sign_tx.py +++ b/tests/device_tests/cardano/test_sign_tx.py @@ -63,7 +63,7 @@ def test_cardano_sign_tx(session: Session, parameters, result): response = call_sign_tx( session, parameters, - input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), + input_flow=lambda client: InputFlowConfirmAllWarnings(session.client).get(), ) assert response == _transform_expected_result(result) @@ -122,10 +122,10 @@ def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool = else: device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict) - with session.client as client: + with session: if input_flow is not None: - client.watch_layout() - client.set_input_flow(input_flow(client)) + session.client.watch_layout() + session.set_input_flow(input_flow(session.client)) return cardano.sign_tx( session=session, diff --git a/tests/device_tests/eos/test_get_public_key.py b/tests/device_tests/eos/test_get_public_key.py index d99c54cb2b..124c50c41b 100644 --- a/tests/device_tests/eos/test_get_public_key.py +++ b/tests/device_tests/eos/test_get_public_key.py @@ -29,9 +29,9 @@ from ...input_flows import InputFlowShowXpubQRCode @pytest.mark.models("t2t1") @pytest.mark.setup_client(mnemonic=MNEMONIC12) def test_eos_get_public_key(session: Session): - with session.client as client: - IF = InputFlowShowXpubQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowXpubQRCode(session.client) + session.set_input_flow(IF.get()) public_key = get_public_key( session, parse_path("m/44h/194h/0h/0/0"), show_display=True ) diff --git a/tests/device_tests/ethereum/test_definitions.py b/tests/device_tests/ethereum/test_definitions.py index 9cc3fd5704..052a09187d 100644 --- a/tests/device_tests/ethereum/test_definitions.py +++ b/tests/device_tests/ethereum/test_definitions.py @@ -123,9 +123,9 @@ def test_external_token(session: Session) -> None: def test_external_chain_without_token(session: Session) -> None: - with session, session.client as client: - if not client.debug.legacy_debug: - client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) + with session: + if not session.client.debug.legacy_debug: + session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get()) # when using an external chains, unknown tokens are allowed network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_ERC20_PARAMS.copy() @@ -145,9 +145,9 @@ def test_external_chain_token_ok(session: Session) -> None: def test_external_chain_token_mismatch(session: Session) -> None: - with session, session.client as client: - if not client.debug.legacy_debug: - client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) + with session: + if not session.client.debug.legacy_debug: + session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get()) # when providing external defs, we explicitly allow, but not use, tokens # from other chains network = common.encode_network(chain_id=66666, slip44=60) diff --git a/tests/device_tests/ethereum/test_getaddress.py b/tests/device_tests/ethereum/test_getaddress.py index b57fcd6afd..a70085a590 100644 --- a/tests/device_tests/ethereum/test_getaddress.py +++ b/tests/device_tests/ethereum/test_getaddress.py @@ -37,9 +37,9 @@ def test_getaddress(session: Session, parameters, result): @pytest.mark.models("core", reason="No input flow for T1") @parametrize_using_common_fixtures("ethereum/getaddress.json") def test_getaddress_chunkify_details(session: Session, parameters, result): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) assert ( ethereum.get_address(session, address_n, show_display=True, chunkify=True) diff --git a/tests/device_tests/ethereum/test_sign_typed_data.py b/tests/device_tests/ethereum/test_sign_typed_data.py index dbb70c0810..ff4fbeec5c 100644 --- a/tests/device_tests/ethereum/test_sign_typed_data.py +++ b/tests/device_tests/ethereum/test_sign_typed_data.py @@ -97,10 +97,10 @@ DATA = { @pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI") def test_ethereum_sign_typed_data_show_more_button(session: Session): - with session.client as client: - client.watch_layout() - IF = InputFlowEIP712ShowMore(client) - client.set_input_flow(IF.get()) + with session: + session.client.watch_layout() + IF = InputFlowEIP712ShowMore(session.client) + session.set_input_flow(IF.get()) ethereum.sign_typed_data( session, parse_path("m/44h/60h/0h/0/0"), @@ -111,10 +111,10 @@ def test_ethereum_sign_typed_data_show_more_button(session: Session): @pytest.mark.models("core") def test_ethereum_sign_typed_data_cancel(session: Session): - with session.client as client, pytest.raises(exceptions.Cancelled): - client.watch_layout() - IF = InputFlowEIP712Cancel(client) - client.set_input_flow(IF.get()) + with session, pytest.raises(exceptions.Cancelled): + session.client.watch_layout() + IF = InputFlowEIP712Cancel(session.client) + session.set_input_flow(IF.get()) ethereum.sign_typed_data( session, parse_path("m/44h/60h/0h/0/0"), diff --git a/tests/device_tests/ethereum/test_sign_verify_message.py b/tests/device_tests/ethereum/test_sign_verify_message.py index c3ef56984c..cea066e5ef 100644 --- a/tests/device_tests/ethereum/test_sign_verify_message.py +++ b/tests/device_tests/ethereum/test_sign_verify_message.py @@ -36,9 +36,9 @@ def test_signmessage(session: Session, parameters, result): assert res.address == result["address"] assert res.signature.hex() == result["sig"] else: - with session.client as client: - IF = InputFlowSignVerifyMessageLong(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignVerifyMessageLong(session.client) + session.set_input_flow(IF.get()) res = ethereum.sign_message( session, parse_path(parameters["path"]), parameters["msg"] ) @@ -57,9 +57,9 @@ def test_verify(session: Session, parameters, result): ) assert res is True else: - with session.client as client: - IF = InputFlowSignVerifyMessageLong(client, verify=True) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSignVerifyMessageLong(session.client, verify=True) + session.set_input_flow(IF.get()) res = ethereum.verify_message( session, parameters["address"], diff --git a/tests/device_tests/ethereum/test_signtx.py b/tests/device_tests/ethereum/test_signtx.py index f57e468a2d..092b5f9b95 100644 --- a/tests/device_tests/ethereum/test_signtx.py +++ b/tests/device_tests/ethereum/test_signtx.py @@ -73,10 +73,10 @@ def _do_test_signtx( input_flow=None, chunkify: bool = False, ): - with session.client as client: + with session: if input_flow: - client.watch_layout() - client.set_input_flow(input_flow) + session.client.watch_layout() + session.set_input_flow(input_flow) sig_v, sig_r, sig_s = ethereum.sign_tx( session, n=parse_path(parameters["path"]), @@ -151,9 +151,9 @@ def test_signtx_go_back_from_summary(session: Session): def test_signtx_eip1559( session: Session, chunkify: bool, parameters: dict, result: dict ): - with session, session.client as client: - if not client.debug.legacy_debug: - client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) + with session: + if not session.client.debug.legacy_debug: + session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get()) sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( session, n=parse_path(parameters["path"]), @@ -456,15 +456,15 @@ def test_signtx_data_pagination(session: Session, flow): data=bytes.fromhex(HEXDATA), ) - with session, session.client as client: - client.watch_layout() - client.set_input_flow(flow(client)) + with session: + session.client.watch_layout() + session.set_input_flow(flow(session.client)) _sign_tx_call() if flow is not input_flow_data_scroll_down: - with session, session.client as client, pytest.raises(exceptions.Cancelled): - client.watch_layout() - client.set_input_flow(flow(client, cancel=True)) + with session, pytest.raises(exceptions.Cancelled): + session.client.watch_layout() + session.set_input_flow(flow(session.client, cancel=True)) _sign_tx_call() diff --git a/tests/device_tests/misc/test_msg_enablelabeling.py b/tests/device_tests/misc/test_msg_enablelabeling.py index e1c0300191..7c5f7559df 100644 --- a/tests/device_tests/misc/test_msg_enablelabeling.py +++ b/tests/device_tests/misc/test_msg_enablelabeling.py @@ -33,8 +33,8 @@ def test_encrypt(client: Client): client.debug.press_yes() session = client.get_session() - with client, session: - client.set_input_flow(input_flow()) + with session: + session.set_input_flow(input_flow()) misc.encrypt_keyvalue( session, [], diff --git a/tests/device_tests/monero/test_getaddress.py b/tests/device_tests/monero/test_getaddress.py index 1a6d3ffc01..3317ad8ce9 100644 --- a/tests/device_tests/monero/test_getaddress.py +++ b/tests/device_tests/monero/test_getaddress.py @@ -56,9 +56,9 @@ def test_monero_getaddress(session: Session, path: str, expected_address: bytes) def test_monero_getaddress_chunkify_details( session: Session, path: str, expected_address: bytes ): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address = monero.get_address( session, parse_path(path), show_display=True, chunkify=True ) diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py index 8841a52426..9574532533 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py @@ -51,10 +51,10 @@ def do_recover_legacy(session: Session, mnemonic: list[str]): def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): - with session.client as client: - client.watch_layout() - IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) - client.set_input_flow(IF.get()) + with session: + session.client.watch_layout() + IF = InputFlowBip39RecoveryDryRun(session.client, mnemonic, mismatch=mismatch) + session.set_input_flow(IF.get()) return device.recover(session, type=messages.RecoveryType.DryRun) @@ -87,10 +87,10 @@ def test_invalid_seed_t1(session: Session): @pytest.mark.models("core") def test_invalid_seed_core(session: Session): - with session, session.client as client: - client.watch_layout() + with session: + session.client.watch_layout() IF = InputFlowBip39RecoveryDryRunInvalid(session) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): return device.recover( session, diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py index abca75bbee..58c6454988 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py @@ -28,9 +28,9 @@ pytestmark = pytest.mark.models("core") @pytest.mark.setup_client(uninitialized=True) @pytest.mark.uninitialized_session def test_tt_pin_passphrase(session: Session): - with session.client as client: - IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" "), pin="654") + session.set_input_flow(IF.get()) device.recover( session, pin_protection=True, @@ -49,9 +49,9 @@ def test_tt_pin_passphrase(session: Session): @pytest.mark.setup_client(uninitialized=True) @pytest.mark.uninitialized_session def test_tt_nopin_nopassphrase(session: Session): - with session.client as client: - IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" ")) + session.set_input_flow(IF.get()) device.recover( session, pin_protection=False, diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py index 3eb0c4d265..0747982857 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py @@ -48,9 +48,11 @@ VECTORS = ( def _test_secret( session: Session, shares: list[str], secret: str, click_info: bool = False ): - with session.client as client: - IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedRecovery( + session.client, shares, click_info=click_info + ) + session.set_input_flow(IF.get()) device.recover( session, pin_protection=False, @@ -89,9 +91,9 @@ def test_extra_share_entered(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort(session: Session): - with session.client as client: - IF = InputFlowSlip39AdvancedRecoveryAbort(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedRecoveryAbort(session.client) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -100,11 +102,11 @@ def test_abort(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_noabort(session: Session): - with session.client as client: + with session: IF = InputFlowSlip39AdvancedRecoveryNoAbort( - client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 + session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") session.refresh_features() assert session.features.initialized is True @@ -118,11 +120,11 @@ def test_same_share(session: Session): # second share is first 4 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] - with session, session.client as client: + with session: IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( session, first_share, second_share ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") @@ -134,10 +136,10 @@ def test_group_threshold_reached(session: Session): # second share is first 3 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] - with session, session.client as client: + with session: IF = InputFlowSlip39AdvancedRecoveryThresholdReached( session, first_share, second_share ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py index 37b4a0264d..dbd0e7781c 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py @@ -40,11 +40,11 @@ EXTRA_GROUP_SHARE = [ @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False) def test_2of3_dryrun(session: Session): - with session.client as client: + with session: IF = InputFlowSlip39AdvancedRecoveryDryRun( - client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 + session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover( session, passphrase_protection=False, @@ -57,13 +57,13 @@ def test_2of3_dryrun(session: Session): @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20) def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with session.client as client, pytest.raises( + with session, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39AdvancedRecoveryDryRun( - client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True + session.client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover( session, passphrase_protection=False, diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py index 1a20899279..ef258b820e 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py @@ -73,9 +73,9 @@ VECTORS = ( def test_secret( session: Session, shares: list[str], secret: str, backup_type: messages.BackupType ): - with session.client as client: - IF = InputFlowSlip39BasicRecovery(client, shares) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecovery(session.client, shares) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") # Workflow successfully ended @@ -89,11 +89,11 @@ def test_secret( @pytest.mark.setup_client(uninitialized=True) def test_recover_with_pin_passphrase(session: Session): - with session.client as client: + with session: IF = InputFlowSlip39BasicRecovery( - client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" + session.client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover( session, pin_protection=True, @@ -109,9 +109,9 @@ def test_recover_with_pin_passphrase(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicRecoveryAbort(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecoveryAbort(session.client) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -123,9 +123,9 @@ def test_abort(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort_on_number_of_words(session: Session): # on Caesar, test_abort actually aborts on the # of words selection - with session.client as client: - IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(session.client) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") assert session.features.initialized is False @@ -134,11 +134,11 @@ def test_abort_on_number_of_words(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort_between_shares(session: Session): - with session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( - client, MNEMONIC_SLIP39_BASIC_20_3of6 + session.client, MNEMONIC_SLIP39_BASIC_20_3of6 ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -148,9 +148,11 @@ def test_abort_between_shares(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_noabort(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecoveryNoAbort( + session.client, MNEMONIC_SLIP39_BASIC_20_3of6 + ) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") session.refresh_features() assert session.features.initialized is True @@ -158,9 +160,9 @@ def test_noabort(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_invalid_mnemonic_first_share(session: Session): - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -169,11 +171,11 @@ def test_invalid_mnemonic_first_share(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_invalid_mnemonic_second_share(session: Session): - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( session, MNEMONIC_SLIP39_BASIC_20_3of6 ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -184,9 +186,9 @@ def test_invalid_mnemonic_second_share(session: Session): @pytest.mark.parametrize("nth_word", range(3)) def test_wrong_nth_word(session: Session, nth_word: int): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") @@ -194,18 +196,18 @@ def test_wrong_nth_word(session: Session, nth_word: int): @pytest.mark.setup_client(uninitialized=True) def test_same_share(session: Session): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoverySameShare(session, share) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") @pytest.mark.setup_client(uninitialized=True) def test_1of1(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecovery(session.client, MNEMONIC_SLIP39_BASIC_20_1of1) + session.set_input_flow(IF.get()) device.recover( session, pin_protection=False, diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py index b9c4ca6daa..b4ffd53f19 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py @@ -38,9 +38,9 @@ INVALID_SHARES_20_2of3 = [ @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) def test_2of3_dryrun(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecoveryDryRun(session.client, SHARES_20_2of3[1:3]) + session.set_input_flow(IF.get()) device.recover( session, passphrase_protection=False, @@ -53,13 +53,13 @@ def test_2of3_dryrun(session: Session): @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) def test_2of3_invalid_seed_dryrun(session: Session): # test fails because of different seed on device - with session.client as client, pytest.raises( + with session, pytest.raises( TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39BasicRecoveryDryRun( - client, INVALID_SHARES_20_2of3, mismatch=True + session.client, INVALID_SHARES_20_2of3, mismatch=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover( session, passphrase_protection=False, diff --git a/tests/device_tests/reset_recovery/test_reset_backup.py b/tests/device_tests/reset_recovery/test_reset_backup.py index 9710ee6201..1f9aa7e3c4 100644 --- a/tests/device_tests/reset_recovery/test_reset_backup.py +++ b/tests/device_tests/reset_recovery/test_reset_backup.py @@ -32,9 +32,9 @@ from ...input_flows import ( def backup_flow_bip39(session: Session) -> bytes: - with session.client as client: - IF = InputFlowBip39Backup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39Backup(session.client) + session.set_input_flow(IF.get()) device.backup(session) assert IF.mnemonic is not None @@ -42,9 +42,9 @@ def backup_flow_bip39(session: Session) -> bytes: def backup_flow_slip39_basic(session: Session): - with session.client as client: - IF = InputFlowSlip39BasicBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) groups = shamir.decode_mnemonics(IF.mnemonics[:3]) @@ -53,9 +53,9 @@ def backup_flow_slip39_basic(session: Session): def backup_flow_slip39_advanced(session: Session): - with session.client as client: - IF = InputFlowSlip39AdvancedBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13] @@ -116,9 +116,9 @@ def test_skip_backup_msg(session: Session, backup_type, backup_flow): def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): assert session.features.initialized is False - with session, session.client as client: - IF = InputFlowResetSkipBackup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowResetSkipBackup(session.client) + session.set_input_flow(IF.get()) device.setup( session, pin_protection=False, diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py index 6e230f21aa..90c86e3d3a 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py @@ -36,9 +36,9 @@ pytestmark = pytest.mark.models("core") def reset_device(session: Session, strength: int): debug = session.client.debug - with session.client as client: - IF = InputFlowBip39ResetBackup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetBackup(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -92,9 +92,9 @@ def test_reset_device_pin(session: Session): debug = session.client.debug strength = 256 # 24 words - with session.client as client: - IF = InputFlowBip39ResetPIN(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetPIN(session.client) + session.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( @@ -130,9 +130,9 @@ def test_reset_device_pin(session: Session): def test_reset_entropy_check(session: Session): strength = 128 # 12 words - with session.client as client: - IF = InputFlowBip39ResetBackup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetBackup(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase path_xpubs = device.setup( @@ -147,7 +147,7 @@ def test_reset_entropy_check(session: Session): ) # Generate the mnemonic locally. - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -156,7 +156,7 @@ def test_reset_entropy_check(session: Session): assert IF.mnemonic == expected_mnemonic # Check that the device is properly initialized. - if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + if session.client.protocol_version is ProtocolVersion.PROTOCOL_V1: features = session.call_raw(messages.Initialize()) else: session.refresh_features() @@ -181,9 +181,9 @@ def test_reset_failed_check(session: Session): debug = session.client.debug strength = 256 # 24 words - with session.client as client: - IF = InputFlowBip39ResetFailedCheck(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetFailedCheck(session.client) + session.set_input_flow(IF.get()) # PIN, passphrase, display random device.setup( diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py index e1ceacbb32..790cce5718 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py @@ -47,9 +47,9 @@ def test_reset_recovery(client: Client): def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: - with session.client as client: - IF = InputFlowBip39ResetBackup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39ResetBackup(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -77,10 +77,10 @@ def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> s def recover(session: Session, mnemonic: str): words = mnemonic.split(" ") - with session.client as client: - IF = InputFlowBip39Recovery(client, words) - client.set_input_flow(IF.get()) - client.watch_layout() + with session: + IF = InputFlowBip39Recovery(session.client, words) + session.set_input_flow(IF.get()) + session.client.watch_layout() device.recover(session, pin_protection=False, label="label") # Workflow successfully ended diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py index 58d7569818..9fbec35dc6 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py @@ -68,9 +68,9 @@ def test_reset_recovery(client: Client): def reset(session: Session, strength: int = 128) -> list[str]: - with session.client as client: - IF = InputFlowSlip39AdvancedResetRecovery(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedResetRecovery(session.client, False) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -97,9 +97,9 @@ def reset(session: Session, strength: int = 128) -> list[str]: def recover(session: Session, shares: list[str]): - with session.client as client: - IF = InputFlowSlip39AdvancedRecovery(client, shares, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedRecovery(session.client, shares, False) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") # Workflow successfully ended diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py index 8e4e53fe47..4f43407680 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py @@ -58,9 +58,9 @@ def test_reset_recovery(client: Client): def reset(session: Session, strength: int = 128) -> list[str]: - with session.client as client: - IF = InputFlowSlip39BasicResetRecovery(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicResetRecovery(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -87,9 +87,9 @@ def reset(session: Session, strength: int = 128) -> list[str]: def recover(session: Session, shares: t.Sequence[str]): - with session.client as client: - IF = InputFlowSlip39BasicRecovery(client, shares) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicRecovery(session.client, shares) + session.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") # Workflow successfully ended diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py index 2d5c9edd4a..3cbda7dc06 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py @@ -34,10 +34,10 @@ def test_reset_device_slip39_advanced(client: Client): strength = 128 member_threshold = 3 - with client: + session = client.get_seedless_session() + with session: IF = InputFlowSlip39AdvancedResetRecovery(client, False) - client.set_input_flow(IF.get()) - session = client.get_seedless_session() + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( session, diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py index dd25fc1342..64b8dd3a87 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py @@ -34,9 +34,9 @@ pytestmark = pytest.mark.models("core") def reset_device(session: Session, strength: int): member_threshold = 3 - with session.client as client: - IF = InputFlowSlip39BasicResetRecovery(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicResetRecovery(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random device.setup( @@ -89,9 +89,9 @@ def test_reset_entropy_check(session: Session): strength = 128 # 20 words - with session.client as client: - IF = InputFlowSlip39BasicResetRecovery(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicResetRecovery(session.client) + session.set_input_flow(IF.get()) # No PIN, no passphrase. path_xpubs = device.setup( diff --git a/tests/device_tests/ripple/test_get_address.py b/tests/device_tests/ripple/test_get_address.py index 2a066926cd..f5247a4728 100644 --- a/tests/device_tests/ripple/test_get_address.py +++ b/tests/device_tests/ripple/test_get_address.py @@ -52,9 +52,9 @@ def test_ripple_get_address(session: Session, path: str, expected_address: str): def test_ripple_get_address_chunkify_details( session: Session, path: str, expected_address: str ): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address = get_address( session, parse_path(path), show_display=True, chunkify=True ) diff --git a/tests/device_tests/solana/test_sign_tx.py b/tests/device_tests/solana/test_sign_tx.py index 708ccdd69f..b0aaefe361 100644 --- a/tests/device_tests/solana/test_sign_tx.py +++ b/tests/device_tests/solana/test_sign_tx.py @@ -47,9 +47,9 @@ pytestmark = [ def test_solana_sign_tx(session: Session, parameters, result): serialized_tx = _serialize_tx(parameters["construct"]) - with session.client as client: - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) actual_result = sign_tx( session, address_n=parse_path(parameters["address"]), diff --git a/tests/device_tests/stellar/test_stellar.py b/tests/device_tests/stellar/test_stellar.py index 1d5c59e1f8..8d6dc70e76 100644 --- a/tests/device_tests/stellar/test_stellar.py +++ b/tests/device_tests/stellar/test_stellar.py @@ -122,9 +122,9 @@ def test_get_address(session: Session, parameters, result): @pytest.mark.models("core") @parametrize_using_common_fixtures("stellar/get_address.json") def test_get_address_chunkify_details(session: Session, parameters, result): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) address = stellar.get_address( session, address_n, show_display=True, chunkify=True diff --git a/tests/device_tests/test_autolock.py b/tests/device_tests/test_autolock.py index a310ff3841..a36487fbcb 100644 --- a/tests/device_tests/test_autolock.py +++ b/tests/device_tests/test_autolock.py @@ -38,8 +38,8 @@ def pin_request(session: Session): def set_autolock_delay(session: Session, delay): - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ pin_request(session), @@ -61,8 +61,8 @@ def test_apply_auto_lock_delay(session: Session): get_test_address(session) time.sleep(10.5) # sleep more than auto-lock delay - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses([pin_request(session), messages.Address]) get_test_address(session) @@ -85,8 +85,8 @@ def test_apply_auto_lock_delay_valid(session: Session, seconds): def test_autolock_default_value(session: Session): assert session.features.auto_lock_delay_ms is None - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) device.apply_settings(session, label="pls unlock") session.refresh_features() assert session.features.auto_lock_delay_ms == 60 * 10 * 1000 @@ -98,8 +98,8 @@ def test_autolock_default_value(session: Session): ) def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ pin_request(session), diff --git a/tests/device_tests/test_busy_state.py b/tests/device_tests/test_busy_state.py index 7de774aeaf..5818161ac8 100644 --- a/tests/device_tests/test_busy_state.py +++ b/tests/device_tests/test_busy_state.py @@ -48,8 +48,8 @@ def test_busy_state(session: Session): _assert_busy(session, True) assert session.features.unlocked is False - with session.client as client: - client.use_pin_sequence([PIN]) + with session: + session.client.use_pin_sequence([PIN]) btc.get_address( session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True ) diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index a7fa64a454..06cb39cde7 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -40,9 +40,9 @@ def test_cancel_message_via_cancel(session: Session, message): yield session.cancel() - with session, session.client as client, pytest.raises(Cancelled): + with session, pytest.raises(Cancelled): session.set_expected_responses([m.ButtonRequest(), m.Failure()]) - client.set_input_flow(input_flow) + session.set_input_flow(input_flow) session.call(message) diff --git a/tests/device_tests/test_debuglink.py b/tests/device_tests/test_debuglink.py index 4123b5e1b4..b844ac0371 100644 --- a/tests/device_tests/test_debuglink.py +++ b/tests/device_tests/test_debuglink.py @@ -47,12 +47,12 @@ def test_pin(session: Session): ) assert isinstance(resp, messages.PinMatrixRequest) - with session.client as client: - state = client.debug.state() + with session: + state = session.client.debug.state() assert state.pin == "1234" assert state.matrix != "" - pin_encoded = client.debug.encode_pin("1234") + pin_encoded = session.client.debug.encode_pin("1234") resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) assert isinstance(resp, messages.PassphraseRequest) diff --git a/tests/device_tests/test_language.py b/tests/device_tests/test_language.py index 0fe6e27595..dd1d8ad744 100644 --- a/tests/device_tests/test_language.py +++ b/tests/device_tests/test_language.py @@ -79,9 +79,9 @@ def _check_ping_screen_texts(session: Session, title: str, right_button: str) -> if session.model in (models.T2T1, models.T3T1): right_button = "-" - with session, session.client as client: - client.watch_layout(True) - client.set_input_flow(ping_input_flow(session, title, right_button)) + with session: + session.client.watch_layout(True) + session.set_input_flow(ping_input_flow(session, title, right_button)) ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) assert ping == messages.Success(message="ahoj!") @@ -274,8 +274,8 @@ def test_reject_update(session: Session): yield session.client.debug.press_no() - with pytest.raises(exceptions.Cancelled), session, session.client as client: - client.set_input_flow(input_flow_reject) + with pytest.raises(exceptions.Cancelled), session: + session.set_input_flow(input_flow_reject) device.change_language(session, language_data) assert session.features.language == "en-US" diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 40c18d2cab..5fc3684fbb 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -345,12 +345,12 @@ def test_safety_checks(session: Session): assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest, messages.ButtonRequest, messages.Address] ) - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) get_bad_address() with session: @@ -371,13 +371,13 @@ def test_safety_checks(session: Session): assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest, messages.ButtonRequest, messages.Address] ) if session.model is not models.T1B1: - IF = InputFlowConfirmAllWarnings(client) - client.set_input_flow(IF.get()) + IF = InputFlowConfirmAllWarnings(session.client) + session.set_input_flow(IF.get()) get_bad_address() @@ -412,8 +412,8 @@ def test_experimental_features(session: Session): # relock and try again session.lock() - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses([messages.ButtonRequest, messages.Nonce]) experimental_call() diff --git a/tests/device_tests/test_msg_backup_device.py b/tests/device_tests/test_msg_backup_device.py index 56d96ce14a..c7a8156b50 100644 --- a/tests/device_tests/test_msg_backup_device.py +++ b/tests/device_tests/test_msg_backup_device.py @@ -44,9 +44,9 @@ from ..input_flows import ( def test_backup_bip39(session: Session): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: - IF = InputFlowBip39Backup(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowBip39Backup(session.client) + session.set_input_flow(IF.get()) device.backup(session) assert IF.mnemonic == MNEMONIC12 @@ -71,9 +71,9 @@ def test_backup_slip39_basic(session: Session, click_info: bool): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: - IF = InputFlowSlip39BasicBackup(client, click_info) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, click_info) + session.set_input_flow(IF.get()) device.backup(session) session.refresh_features() @@ -95,11 +95,12 @@ def test_backup_slip39_basic(session: Session, click_info: bool): def test_backup_slip39_single(session: Session): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: + with session: IF = InputFlowBip39Backup( - client, confirm_success=(client.layout_type is not LayoutType.Delizia) + session.client, + confirm_success=(session.client.layout_type is not LayoutType.Delizia), ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.backup(session) assert session.features.initialized is True @@ -126,9 +127,9 @@ def test_backup_slip39_advanced(session: Session, click_info: bool): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: - IF = InputFlowSlip39AdvancedBackup(client, click_info) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39AdvancedBackup(session.client, click_info) + session.set_input_flow(IF.get()) device.backup(session) session.refresh_features() @@ -157,9 +158,9 @@ def test_backup_slip39_advanced(session: Session, click_info: bool): def test_backup_slip39_custom(session: Session, share_threshold, share_count): assert session.features.backup_availability == messages.BackupAvailability.Required - with session.client as client: - IF = InputFlowSlip39CustomBackup(client, share_count) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39CustomBackup(session.client, share_count) + session.set_input_flow(IF.get()) device.backup( session, group_threshold=1, groups=[(share_threshold, share_count)] ) diff --git a/tests/device_tests/test_msg_change_wipe_code_t1.py b/tests/device_tests/test_msg_change_wipe_code_t1.py index 8de1439787..c66c386a7a 100644 --- a/tests/device_tests/test_msg_change_wipe_code_t1.py +++ b/tests/device_tests/test_msg_change_wipe_code_t1.py @@ -34,7 +34,7 @@ pytestmark = pytest.mark.models("legacy") def _set_wipe_code(session: Session, pin, wipe_code): # Set/change wipe code. - with session.client as client, session: + with session: if session.features.pin_protection: pins = [pin, wipe_code, wipe_code] pin_matrices = [ @@ -49,7 +49,7 @@ def _set_wipe_code(session: Session, pin, wipe_code): messages.PinMatrixRequest(type=PinType.WipeCodeSecond), ] - client.use_pin_sequence(pins) + session.client.use_pin_sequence(pins) session.set_expected_responses( [messages.ButtonRequest()] + pin_matrices + [messages.Success] ) @@ -58,8 +58,8 @@ def _set_wipe_code(session: Session, pin, wipe_code): def _change_pin(session: Session, old_pin, new_pin): assert session.features.pin_protection is True - with session.client as client: - client.use_pin_sequence([old_pin, new_pin, new_pin]) + with session: + session.client.use_pin_sequence([old_pin, new_pin, new_pin]) try: return device.change_pin(session) except exceptions.TrezorFailure as f: @@ -96,8 +96,8 @@ def test_set_remove_wipe_code(session: Session): _check_wipe_code(session, PIN4, WIPE_CODE6) # Test remove wipe code. - with session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) device.change_wipe_code(session, remove=True) # Check that there's no wipe code protection now. @@ -111,8 +111,8 @@ def test_set_wipe_code_mismatch(session: Session): assert session.features.wipe_code_protection is False # Let's set a new wipe code. - with session.client as client, session: - client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6]) + with session: + session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6]) session.set_expected_responses( [ messages.ButtonRequest(), @@ -125,8 +125,8 @@ def test_set_wipe_code_mismatch(session: Session): device.change_wipe_code(session) # Check that there is no wipe code protection. - client.refresh_features() - assert client.features.wipe_code_protection is False + session.client.refresh_features() + assert session.client.features.wipe_code_protection is False @pytest.mark.setup_client(pin=PIN4) @@ -135,8 +135,8 @@ def test_set_wipe_code_to_pin(session: Session): assert session.features.wipe_code_protection is None # Let's try setting the wipe code to the curent PIN value. - with session.client as client, session: - client.use_pin_sequence([PIN4, PIN4]) + with session: + session.client.use_pin_sequence([PIN4, PIN4]) session.set_expected_responses( [ messages.ButtonRequest(), @@ -149,8 +149,8 @@ def test_set_wipe_code_to_pin(session: Session): device.change_wipe_code(session) # Check that there is no wipe code protection. - client.refresh_features() - assert client.features.wipe_code_protection is False + session.client.refresh_features() + assert session.client.features.wipe_code_protection is False def test_set_pin_to_wipe_code(session: Session): @@ -159,8 +159,8 @@ def test_set_pin_to_wipe_code(session: Session): _set_wipe_code(session, None, WIPE_CODE4) # Try to set the PIN to the current wipe code value. - with session.client as client, session: - client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) + with session: + session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) session.set_expected_responses( [ messages.ButtonRequest(), diff --git a/tests/device_tests/test_msg_change_wipe_code_t2.py b/tests/device_tests/test_msg_change_wipe_code_t2.py index 9142b6dc95..92e569aafc 100644 --- a/tests/device_tests/test_msg_change_wipe_code_t2.py +++ b/tests/device_tests/test_msg_change_wipe_code_t2.py @@ -37,8 +37,8 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str): assert session.features.wipe_code_protection is True # Try to change the PIN to the current wipe code value. The operation should fail. - with session, session.client as client, pytest.raises(TrezorFailure): - client.use_pin_sequence([pin, wipe_code, wipe_code]) + with session, pytest.raises(TrezorFailure): + session.client.use_pin_sequence([pin, wipe_code, wipe_code]) if session.client.layout_type is LayoutType.Caesar: br_count = 6 else: @@ -51,8 +51,8 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str): def _ensure_unlocked(session: Session, pin: str): - with session, session.client as client: - client.use_pin_sequence([pin]) + with session: + session.client.use_pin_sequence([pin]) btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH) session.refresh_features() @@ -71,11 +71,11 @@ def test_set_remove_wipe_code(session: Session): else: br_count = 5 - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest()] * br_count + [messages.Success] ) - client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX]) + session.client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX]) device.change_wipe_code(session) # session.init_device() @@ -83,11 +83,11 @@ def test_set_remove_wipe_code(session: Session): _check_wipe_code(session, PIN4, WIPE_CODE_MAX) # Test change wipe code. - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest()] * br_count + [messages.Success] ) - client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6]) + session.client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6]) device.change_wipe_code(session) # session.init_device() @@ -95,11 +95,11 @@ def test_set_remove_wipe_code(session: Session): _check_wipe_code(session, PIN4, WIPE_CODE6) # Test remove wipe code. - with session, session.client as client: + with session: session.set_expected_responses( [messages.ButtonRequest()] * 3 + [messages.Success] ) - client.use_pin_sequence([PIN4]) + session.client.use_pin_sequence([PIN4]) device.change_wipe_code(session, remove=True) # session.init_device() @@ -107,9 +107,11 @@ def test_set_remove_wipe_code(session: Session): def test_set_wipe_code_mismatch(session: Session): - with session, session.client as client, pytest.raises(TrezorFailure): - IF = InputFlowNewCodeMismatch(client, WIPE_CODE4, WIPE_CODE6, what="wipe_code") - client.set_input_flow(IF.get()) + with session, pytest.raises(TrezorFailure): + IF = InputFlowNewCodeMismatch( + session.client, WIPE_CODE4, WIPE_CODE6, what="wipe_code" + ) + session.set_input_flow(IF.get()) device.change_wipe_code(session) @@ -122,15 +124,15 @@ def test_set_wipe_code_mismatch(session: Session): def test_set_wipe_code_to_pin(session: Session): _ensure_unlocked(session, PIN4) - with session, session.client as client: - if client.layout_type is LayoutType.Caesar: + with session: + if session.client.layout_type is LayoutType.Caesar: br_count = 8 else: br_count = 7 session.set_expected_responses( [messages.ButtonRequest()] * br_count + [messages.Success], ) - client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4]) + session.client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4]) device.change_wipe_code(session) # session.init_device() @@ -140,20 +142,20 @@ def test_set_wipe_code_to_pin(session: Session): def test_set_pin_to_wipe_code(session: Session): # Set wipe code. - with session, session.client as client: - if client.layout_type is LayoutType.Caesar: + with session: + if session.client.layout_type is LayoutType.Caesar: br_count = 5 else: br_count = 4 session.set_expected_responses( [messages.ButtonRequest()] * br_count + [messages.Success] ) - client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) + session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) device.change_wipe_code(session) # Try to set the PIN to the current wipe code value. - with session, session.client as client, pytest.raises(TrezorFailure): - if client.layout_type is LayoutType.Caesar: + with session, pytest.raises(TrezorFailure): + if session.client.layout_type is LayoutType.Caesar: br_count = 6 else: br_count = 4 @@ -161,5 +163,5 @@ def test_set_pin_to_wipe_code(session: Session): [messages.ButtonRequest()] * br_count + [messages.Failure(code=messages.FailureType.PinInvalid)] ) - client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) + session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) device.change_pin(session) diff --git a/tests/device_tests/test_msg_changepin_t1.py b/tests/device_tests/test_msg_changepin_t1.py index 3404e44a36..0ed0013502 100644 --- a/tests/device_tests/test_msg_changepin_t1.py +++ b/tests/device_tests/test_msg_changepin_t1.py @@ -33,8 +33,8 @@ pytestmark = pytest.mark.models("legacy") def _check_pin(session: Session, pin): session.lock() - with session, session.client as client: - client.use_pin_sequence([pin]) + with session: + session.client.use_pin_sequence([pin]) session.set_expected_responses([messages.PinMatrixRequest, messages.Address]) get_test_address(session) @@ -53,8 +53,8 @@ def test_set_pin(session: Session): _check_no_pin(session) # Let's set new PIN - with session, session.client as client: - client.use_pin_sequence([PIN_MAX, PIN_MAX]) + with session: + session.client.use_pin_sequence([PIN_MAX, PIN_MAX]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), @@ -78,8 +78,8 @@ def test_change_pin(session: Session): _check_pin(session, PIN4) # Let's change PIN - with session, session.client as client: - client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) + with session: + session.client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), @@ -104,8 +104,8 @@ def test_remove_pin(session: Session): _check_pin(session, PIN4) # Let's remove PIN - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), @@ -126,11 +126,9 @@ def test_set_mismatch(session: Session): _check_no_pin(session) # Let's set new PIN - with session, session.client as client, pytest.raises( - TrezorFailure, match="PIN mismatch" - ): + with session, pytest.raises(TrezorFailure, match="PIN mismatch"): # use different PINs for first and second attempt. This will fail. - client.use_pin_sequence([PIN4, PIN_MAX]) + session.client.use_pin_sequence([PIN4, PIN_MAX]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), @@ -152,10 +150,8 @@ def test_change_mismatch(session: Session): assert session.features.pin_protection is True # Let's set new PIN - with session, session.client as client, pytest.raises( - TrezorFailure, match="PIN mismatch" - ): - client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"]) + with session, pytest.raises(TrezorFailure, match="PIN mismatch"): + session.client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), diff --git a/tests/device_tests/test_msg_changepin_t2.py b/tests/device_tests/test_msg_changepin_t2.py index 7c6d9ba72e..d740cb0ae4 100644 --- a/tests/device_tests/test_msg_changepin_t2.py +++ b/tests/device_tests/test_msg_changepin_t2.py @@ -37,9 +37,9 @@ pytestmark = pytest.mark.models("core") def _check_pin(session: Session, pin: str): - with session, session.client as client: - client.ui.__init__(client.debug) - client.use_pin_sequence([pin, pin, pin, pin, pin, pin]) + with session: + session.client.ui.__init__(session.client.debug) + session.client.use_pin_sequence([pin, pin, pin, pin, pin, pin]) session.lock() assert session.features.pin_protection is True assert session.features.unlocked is False @@ -63,12 +63,12 @@ def test_set_pin(session: Session): _check_no_pin(session) # Let's set new PIN - with session, session.client as client: - if client.layout_type is LayoutType.Caesar: + with session: + if session.client.layout_type is LayoutType.Caesar: br_count = 6 else: br_count = 4 - client.use_pin_sequence([PIN_MAX, PIN_MAX]) + session.client.use_pin_sequence([PIN_MAX, PIN_MAX]) session.set_expected_responses( [messages.ButtonRequest] * br_count + [messages.Success] ) @@ -86,9 +86,9 @@ def test_change_pin(session: Session): _check_pin(session, PIN4) # Let's change PIN - with session, session.client as client: - client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) - if client.layout_type is LayoutType.Caesar: + with session: + session.client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) + if session.client.layout_type is LayoutType.Caesar: br_count = 6 else: br_count = 5 @@ -113,8 +113,8 @@ def test_remove_pin(session: Session): _check_pin(session, PIN4) # Let's remove PIN - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [messages.ButtonRequest] * 3 + [messages.Success] ) @@ -132,9 +132,9 @@ def test_set_failed(session: Session): # Check that there's no PIN protection _check_no_pin(session) - with session, session.client as client, pytest.raises(TrezorFailure): - IF = InputFlowNewCodeMismatch(client, PIN4, PIN60, what="pin") - client.set_input_flow(IF.get()) + with session, pytest.raises(TrezorFailure): + IF = InputFlowNewCodeMismatch(session.client, PIN4, PIN60, what="pin") + session.set_input_flow(IF.get()) device.change_pin(session) @@ -151,9 +151,9 @@ def test_change_failed(session: Session): # Check current PIN value _check_pin(session, PIN4) - with session, session.client as client, pytest.raises(Cancelled): + with session, pytest.raises(Cancelled): IF = InputFlowCodeChangeFail(session, PIN4, "457891", "381847") - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.change_pin(session) @@ -170,9 +170,9 @@ def test_change_invalid_current(session: Session): # Check current PIN value _check_pin(session, PIN4) - with session, session.client as client, pytest.raises(TrezorFailure): - IF = InputFlowWrongPIN(client, PIN60) - client.set_input_flow(IF.get()) + with session, pytest.raises(TrezorFailure): + IF = InputFlowWrongPIN(session.client, PIN60) + session.set_input_flow(IF.get()) device.change_pin(session) @@ -200,7 +200,7 @@ def test_pin_menu_cancel_setup(session: Session): # tap to confirm debug.click(debug.screen_buttons.tap_to_confirm()) - with session, session.client as client, pytest.raises(Cancelled): - client.set_input_flow(cancel_pin_setup_input_flow) + with session, pytest.raises(Cancelled): + session.set_input_flow(cancel_pin_setup_input_flow) session.call(messages.ChangePin()) _check_no_pin(session) diff --git a/tests/device_tests/test_msg_wipedevice.py b/tests/device_tests/test_msg_wipedevice.py index d46be75e84..7bec80797d 100644 --- a/tests/device_tests/test_msg_wipedevice.py +++ b/tests/device_tests/test_msg_wipedevice.py @@ -45,9 +45,8 @@ def test_wipe_device(client: Client): @pytest.mark.setup_client(pin=PIN4) def test_autolock_not_retained(session: Session): client = session.client - with client: - client.use_pin_sequence([PIN4]) - device.apply_settings(session, auto_lock_delay_ms=10_000) + client.use_pin_sequence([PIN4]) + device.apply_settings(session, auto_lock_delay_ms=10_000) assert session.features.auto_lock_delay_ms == 10_000 @@ -57,21 +56,20 @@ def test_autolock_not_retained(session: Session): assert client.features.auto_lock_delay_ms > 10_000 - with client: - client.use_pin_sequence([PIN4, PIN4]) - device.setup( - session, - skip_backup=True, - pin_protection=True, - passphrase_protection=False, - entropy_check_count=0, - backup_type=messages.BackupType.Bip39, - ) + client.use_pin_sequence([PIN4, PIN4]) + device.setup( + session, + skip_backup=True, + pin_protection=True, + passphrase_protection=False, + entropy_check_count=0, + backup_type=messages.BackupType.Bip39, + ) time.sleep(10.5) session = client.get_session() - with session, client: + with session: # after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked session.set_expected_responses([messages.Address]) get_test_address(session) diff --git a/tests/device_tests/test_pin.py b/tests/device_tests/test_pin.py index c911dfee50..b5b7981d99 100644 --- a/tests/device_tests/test_pin.py +++ b/tests/device_tests/test_pin.py @@ -39,8 +39,8 @@ def test_no_protection(session: Session): def test_correct_pin(session: Session): - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) # Expected responses differ between T1 and TT is_t1 = session.model is models.T1B1 session.set_expected_responses( @@ -65,9 +65,9 @@ def test_incorrect_pin_t1(session: Session): @pytest.mark.models("core") def test_incorrect_pin_t2(session: Session): - with session, session.client as client: + with session: # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt - client.use_pin_sequence([BAD_PIN, PIN4]) + session.client.use_pin_sequence([BAD_PIN, PIN4]) session.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), @@ -82,15 +82,15 @@ def test_incorrect_pin_t2(session: Session): def test_exponential_backoff_t1(session: Session): for attempt in range(3): start = time.time() - with session, session.client as client, pytest.raises(PinException): - client.use_pin_sequence([BAD_PIN]) + with session, pytest.raises(PinException): + session.client.use_pin_sequence([BAD_PIN]) get_test_address(session) check_pin_backoff_time(attempt, start) @pytest.mark.models("core") def test_exponential_backoff_t2(session: Session): - with session.client as client: - IF = InputFlowPINBackoff(client, BAD_PIN, PIN4) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowPINBackoff(session.client, BAD_PIN, PIN4) + session.set_input_flow(IF.get()) get_test_address(session) diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index 0615e41508..083e91e93b 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -56,12 +56,12 @@ def _assert_protection( session: Session, pin: bool = True, passphrase: bool = True ) -> Session: """Make sure PIN and passphrase protection have expected values""" - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.ensure_unlocked() - client.refresh_features() - assert client.features.pin_protection is pin - assert client.features.passphrase_protection is passphrase + session.client.refresh_features() + assert session.client.features.pin_protection is pin + assert session.client.features.passphrase_protection is passphrase session.lock() # session.end() if session.protocol_version == ProtocolVersion.PROTOCOL_V1: @@ -70,8 +70,8 @@ def _assert_protection( def test_initialize(session: Session): - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.ensure_unlocked() session = _assert_protection(session) with session: @@ -86,8 +86,8 @@ def test_passphrase_reporting(session: Session, passphrase): """On TT, passphrase_protection is a private setting, so a locked device should report passphrase_protection=None. """ - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) device.apply_settings(session, use_passphrase=passphrase) session.lock() @@ -108,8 +108,8 @@ def test_passphrase_reporting(session: Session, passphrase): def test_apply_settings(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ _pin_request(session), @@ -124,8 +124,8 @@ def test_apply_settings(session: Session): @pytest.mark.models("legacy") def test_change_pin_t1(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4, PIN4, PIN4]) + with session: + session.client.use_pin_sequence([PIN4, PIN4, PIN4]) session.set_expected_responses( [ messages.ButtonRequest, @@ -141,8 +141,8 @@ def test_change_pin_t1(session: Session): @pytest.mark.models("core") def test_change_pin_t2(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) + with session: + session.client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) session.set_expected_responses( [ _pin_request(session), @@ -172,8 +172,8 @@ def test_ping(session: Session): def test_get_entropy(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ _pin_request(session), @@ -187,8 +187,8 @@ def test_get_entropy(session: Session): def test_get_public_key(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] if session.protocol_version == ProtocolVersion.PROTOCOL_V1: @@ -202,8 +202,8 @@ def test_get_public_key(session: Session): def test_get_address(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] if session.protocol_version == ProtocolVersion.PROTOCOL_V1: expected_responses.append(messages.PassphraseRequest) @@ -221,8 +221,8 @@ def test_wipe_device(session: Session): device.wipe(session) client = session.client.get_new_client() session = client.get_seedless_session() - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses([messages.Features]) session.call(messages.GetFeatures()) @@ -301,8 +301,8 @@ def test_recovery_device(session: Session): def test_sign_message(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] @@ -350,8 +350,8 @@ def test_verify_message_t1(session: Session): @pytest.mark.models("core") def test_verify_message_t2(session: Session): session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses( [ _pin_request(session), @@ -389,8 +389,8 @@ def test_signtx(session: Session): ) session = _assert_protection(session) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] if session.protocol_version == ProtocolVersion.PROTOCOL_V1: expected_responses.append(messages.PassphraseRequest) @@ -430,8 +430,8 @@ def test_unlocked(session: Session): session = _assert_protection(session, passphrase=False) - with session, session.client as client: - client.use_pin_sequence([PIN4]) + with session: + session.client.use_pin_sequence([PIN4]) session.set_expected_responses([_pin_request(session), messages.Address]) get_test_address(session) diff --git a/tests/device_tests/test_repeated_backup.py b/tests/device_tests/test_repeated_backup.py index 601c898fbb..2eb28cd32a 100644 --- a/tests/device_tests/test_repeated_backup.py +++ b/tests/device_tests/test_repeated_backup.py @@ -39,9 +39,9 @@ def test_repeated_backup(session: Session): # initial device backup mnemonics = [] - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) mnemonics = IF.mnemonics @@ -56,11 +56,11 @@ def test_repeated_backup(session: Session): device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryDryRun( - client, mnemonics[:3], unlock_repeated_backup=True + session.client, mnemonics[:3], unlock_repeated_backup=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( session.features.backup_availability @@ -69,9 +69,9 @@ def test_repeated_backup(session: Session): assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False, repeated=True) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False, repeated=True) + session.set_input_flow(IF.get()) device.backup(session) # the backup feature is locked again... @@ -92,11 +92,11 @@ def test_repeated_backup_upgrade_single(session: Session): assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable # unlock repeated backup by entering the single share - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryDryRun( - client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True + session.client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( session.features.backup_availability @@ -105,9 +105,9 @@ def test_repeated_backup_upgrade_single(session: Session): assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False, repeated=True) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False, repeated=True) + session.set_input_flow(IF.get()) device.backup(session) # backup type was upgraded: @@ -128,9 +128,9 @@ def test_repeated_backup_cancel(session: Session): # initial device backup mnemonics = [] - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) mnemonics = IF.mnemonics @@ -145,11 +145,11 @@ def test_repeated_backup_cancel(session: Session): device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryDryRun( - client, mnemonics[:3], unlock_repeated_backup=True + session.client, mnemonics[:3], unlock_repeated_backup=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( session.features.backup_availability @@ -183,9 +183,9 @@ def test_repeated_backup_send_disallowed_message(session: Session): # initial device backup mnemonics = [] - with session, session.client as client: - IF = InputFlowSlip39BasicBackup(client, False) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowSlip39BasicBackup(session.client, False) + session.set_input_flow(IF.get()) device.backup(session) mnemonics = IF.mnemonics @@ -200,11 +200,11 @@ def test_repeated_backup_send_disallowed_message(session: Session): device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with session, session.client as client: + with session: IF = InputFlowSlip39BasicRecoveryDryRun( - client, mnemonics[:3], unlock_repeated_backup=True + session.client, mnemonics[:3], unlock_repeated_backup=True ) - client.set_input_flow(IF.get()) + session.set_input_flow(IF.get()) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) assert ( session.features.backup_availability diff --git a/tests/device_tests/test_sdcard.py b/tests/device_tests/test_sdcard.py index 8d5c45b81f..2faaf71a43 100644 --- a/tests/device_tests/test_sdcard.py +++ b/tests/device_tests/test_sdcard.py @@ -45,8 +45,8 @@ def test_sd_no_format(session: Session): yield # format SD card debug.press_no() - with session, session.client as client, pytest.raises(TrezorFailure) as e: - client.set_input_flow(input_flow) + with session, pytest.raises(TrezorFailure) as e: + session.set_input_flow(input_flow) device.sd_protect(session, Op.ENABLE) assert e.value.code == messages.FailureType.ProcessError @@ -76,9 +76,9 @@ def test_sd_protect_unlock(session: Session): assert TR.sd_card__enabled in layout().text_content() debug.press_yes() - with session, session.client as client: - client.watch_layout() - client.set_input_flow(input_flow_enable_sd_protect) + with session: + session.client.watch_layout() + session.set_input_flow(input_flow_enable_sd_protect) device.sd_protect(session, Op.ENABLE) def input_flow_change_pin(): @@ -102,9 +102,9 @@ def test_sd_protect_unlock(session: Session): assert TR.pin__changed in layout().text_content() debug.press_yes() - with session, session.client as client: - client.watch_layout() - client.set_input_flow(input_flow_change_pin) + with session: + session.client.watch_layout() + session.set_input_flow(input_flow_change_pin) device.change_pin(session) debug.erase_sd_card(format=False) @@ -125,9 +125,9 @@ def test_sd_protect_unlock(session: Session): ) debug.press_no() # close - with session, session.client as client, pytest.raises(TrezorFailure) as e: - client.watch_layout() - client.set_input_flow(input_flow_change_pin_format) + with session, pytest.raises(TrezorFailure) as e: + session.client.watch_layout() + session.set_input_flow(input_flow_change_pin_format) device.change_pin(session) assert e.value.code == messages.FailureType.ProcessError diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index 5e8a850b5f..19cbfb2d95 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -41,7 +41,7 @@ def test_clear_session(client: Client): cached_responses = [messages.PublicKey] session = client.get_session() session.lock() - with client, session: + with session: client.use_pin_sequence([PIN4]) session.set_expected_responses(init_responses + cached_responses) assert get_public_node(session, ADDRESS_N).xpub == XPUB @@ -57,7 +57,7 @@ def test_clear_session(client: Client): session = client.get_session() # session cache is cleared - with client, session: + with session: client.use_pin_sequence([PIN4]) session.set_expected_responses(init_responses + cached_responses) assert get_public_node(session, ADDRESS_N).xpub == XPUB @@ -76,7 +76,7 @@ def test_end_session(client: Client): assert session.id is not None # get_address will succeed - with session: + with session as session: session.set_expected_responses([messages.Address]) get_test_address(session) @@ -135,7 +135,7 @@ def test_end_session_only_current(client: Client): @pytest.mark.setup_client(passphrase=True) def test_session_recycling(client: Client): session = client.get_session(passphrase="TREZOR") - with client, session: + with session: session.set_expected_responses( [ messages.PassphraseRequest, @@ -152,7 +152,7 @@ def test_session_recycling(client: Client): session_x.end() # it should still be possible to resume the original session - with client, session: + with session: # passphrase should still be cached session.set_expected_responses([messages.Address] * 3) client.resume_session(session) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 943623aa0c..f710e7162e 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -396,7 +396,7 @@ def test_passphrase_length(client: Client): def test_hide_passphrase_from_host(client: Client): # Without safety checks, turning it on fails session = client.get_seedless_session() - with pytest.raises(TrezorFailure, match="Safety checks are strict"), client: + with pytest.raises(TrezorFailure, match="Safety checks are strict"): device.apply_settings(session, hide_passphrase_from_host=True) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) @@ -406,7 +406,7 @@ def test_hide_passphrase_from_host(client: Client): passphrase = "abc" session = client.get_session(passphrase=passphrase) - with client, session: + with session: def input_flow(): yield @@ -421,8 +421,8 @@ def test_hide_passphrase_from_host(client: Client): else: raise KeyError - client.watch_layout() - client.set_input_flow(input_flow) + session.client.watch_layout() + session.set_input_flow(input_flow) session.set_expected_responses( [ messages.PassphraseRequest, @@ -440,7 +440,7 @@ def test_hide_passphrase_from_host(client: Client): # Starting new session, otherwise the passphrase would be cached session = client.get_session(passphrase=passphrase) - with client, session: + with session: def input_flow(): yield @@ -455,8 +455,8 @@ def test_hide_passphrase_from_host(client: Client): assert passphrase in client.debug.read_layout().text_content() client.debug.press_yes() - client.watch_layout() - client.set_input_flow(input_flow) + session.client.watch_layout() + session.set_input_flow(input_flow) session.set_expected_responses( [ messages.PassphraseRequest, diff --git a/tests/device_tests/tezos/test_getaddress.py b/tests/device_tests/tezos/test_getaddress.py index 9f35118370..4bac751148 100644 --- a/tests/device_tests/tezos/test_getaddress.py +++ b/tests/device_tests/tezos/test_getaddress.py @@ -44,9 +44,9 @@ def test_tezos_get_address(session: Session, path: str, expected_address: str): def test_tezos_get_address_chunkify_details( session: Session, path: str, expected_address: str ): - with session.client as client: - IF = InputFlowShowAddressQRCode(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowShowAddressQRCode(session.client) + session.set_input_flow(IF.get()) address = get_address( session, parse_path(path), show_display=True, chunkify=True ) diff --git a/tests/device_tests/webauthn/test_msg_webauthn.py b/tests/device_tests/webauthn/test_msg_webauthn.py index 7016e2f5f8..4550c01077 100644 --- a/tests/device_tests/webauthn/test_msg_webauthn.py +++ b/tests/device_tests/webauthn/test_msg_webauthn.py @@ -31,9 +31,9 @@ RK_CAPACITY = 100 @pytest.mark.altcoin @pytest.mark.setup_client(mnemonic=MNEMONIC12) def test_add_remove(session: Session): - with session, session.client as client: - IF = InputFlowFidoConfirm(client) - client.set_input_flow(IF.get()) + with session: + IF = InputFlowFidoConfirm(session.client) + session.set_input_flow(IF.get()) # Remove index 0 should fail. with pytest.raises(TrezorFailure): diff --git a/tests/input_flows.py b/tests/input_flows.py index e222ca1030..5a1bd9bf10 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -50,16 +50,18 @@ class InputFlowBase: # There could be one common input flow for all models if hasattr(self, "input_flow_common"): - return getattr(self, "input_flow_common") + flow = getattr(self, "input_flow_common") elif self.client.layout_type is LayoutType.Bolt: - return self.input_flow_bolt + flow = self.input_flow_bolt elif self.client.layout_type is LayoutType.Caesar: - return self.input_flow_caesar + flow = self.input_flow_caesar elif self.client.layout_type is LayoutType.Delizia: - return self.input_flow_delizia + flow = self.input_flow_delizia else: raise ValueError("Unknown model") + return flow + def input_flow_bolt(self) -> BRGeneratorType: """Special for TT""" raise NotImplementedError @@ -371,7 +373,7 @@ class InputFlowSignMessageInfo(InputFlowBase): self.debug.click(self.client.debug.screen_buttons.vertical_menu_items()[1]) # address mismatch? yes! self.debug.swipe_up() - yield + yield # ? class InputFlowShowAddressQRCode(InputFlowBase): diff --git a/tests/persistence_tests/test_wipe_code.py b/tests/persistence_tests/test_wipe_code.py index 8dee771a6a..cd5c1bc2e3 100644 --- a/tests/persistence_tests/test_wipe_code.py +++ b/tests/persistence_tests/test_wipe_code.py @@ -11,33 +11,37 @@ WIPE_CODE = "9876" def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client.get_seedless_session()) + session = client.get_seedless_session() + device.wipe(session) client = client.get_new_client() + session = client.get_seedless_session() debuglink.load_device( - client.get_seedless_session(), + session, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE", ) - with client: + with session: client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) device.change_wipe_code(client.get_seedless_session()) def setup_device_core(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client.get_seedless_session()) + session = client.get_seedless_session() + device.wipe(session) client = client.get_new_client() + session = client.get_seedless_session() debuglink.load_device( - client.get_seedless_session(), + session, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE", ) - with client: + with session: client.use_pin_sequence([pin, wipe_code, wipe_code]) device.change_wipe_code(client.get_seedless_session()) diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 79951ddafe..64606eb9b6 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -96,7 +96,7 @@ def test_upgrade_load_pin(gen: str, tag: str) -> None: assert client.features.initialized assert client.features.label == LABEL session = client.get_session() - with client, session: + with session: client.use_pin_sequence([PIN]) assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS @@ -395,10 +395,11 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): # Create a backup of the encrypted master secret. assert emu.client.features.backup_availability == BackupAvailability.Required - with emu.client: + session = emu.client.get_session() + with session: IF = InputFlowSlip39BasicBackup(emu.client, False) - emu.client.set_input_flow(IF.get()) - device.backup(emu.client.get_session()) + session.set_input_flow(IF.get()) + device.backup(session) assert ( emu.client.features.backup_availability == BackupAvailability.NotAvailable )