From 9972643779a4b894f2f90890766f8b3890bea6b6 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Thu, 6 Mar 2025 00:35:51 +0100 Subject: [PATCH] refactor(tests): move set_input_flow to SessionDebugWrapper context manager [no changelog] --- python/src/trezorlib/debuglink.py | 582 +++++++----------- tests/burn_tests/burntest_t2.py | 4 +- tests/click_tests/test_autolock.py | 16 +- tests/device_handler.py | 4 + .../device_tests/binance/test_get_address.py | 2 +- .../binance/test_get_public_key.py | 2 +- .../bitcoin/test_authorize_coinjoin.py | 34 +- tests/device_tests/bitcoin/test_bcash.py | 34 +- tests/device_tests/bitcoin/test_bgold.py | 42 +- tests/device_tests/bitcoin/test_dash.py | 8 +- tests/device_tests/bitcoin/test_decred.py | 20 +- .../device_tests/bitcoin/test_descriptors.py | 6 +- tests/device_tests/bitcoin/test_getaddress.py | 28 +- .../bitcoin/test_getaddress_segwit.py | 2 +- .../bitcoin/test_getaddress_show.py | 16 +- .../device_tests/bitcoin/test_getpublickey.py | 4 +- tests/device_tests/bitcoin/test_komodo.py | 8 +- tests/device_tests/bitcoin/test_multisig.py | 17 +- .../bitcoin/test_multisig_change.py | 52 +- .../bitcoin/test_nonstandard_paths.py | 10 +- tests/device_tests/bitcoin/test_op_return.py | 12 +- .../device_tests/bitcoin/test_signmessage.py | 16 +- tests/device_tests/bitcoin/test_signtx.py | 89 +-- .../bitcoin/test_signtx_amount_unit.py | 4 +- .../bitcoin/test_signtx_external.py | 42 +- .../bitcoin/test_signtx_invalid_path.py | 12 +- .../bitcoin/test_signtx_mixed_inputs.py | 8 +- .../bitcoin/test_signtx_payreq.py | 2 +- .../bitcoin/test_signtx_prevhash.py | 8 +- .../bitcoin/test_signtx_replacement.py | 30 +- .../bitcoin/test_signtx_segwit.py | 38 +- .../bitcoin/test_signtx_segwit_native.py | 66 +- .../bitcoin/test_signtx_taproot.py | 22 +- .../bitcoin/test_verifymessage.py | 2 +- tests/device_tests/bitcoin/test_zcash.py | 12 +- .../cardano/test_address_public_key.py | 4 +- tests/device_tests/cardano/test_sign_tx.py | 6 +- tests/device_tests/eos/test_get_public_key.py | 2 +- tests/device_tests/eos/test_signtx.py | 32 +- .../device_tests/ethereum/test_definitions.py | 12 +- .../device_tests/ethereum/test_getaddress.py | 2 +- .../ethereum/test_sign_typed_data.py | 8 +- .../ethereum/test_sign_verify_message.py | 4 +- tests/device_tests/ethereum/test_signtx.py | 26 +- .../misc/test_msg_enablelabeling.py | 2 +- .../device_tests/misc/test_msg_getentropy.py | 4 +- tests/device_tests/monero/test_getaddress.py | 2 +- tests/device_tests/nem/test_signtx_others.py | 2 +- .../device_tests/nem/test_signtx_transfers.py | 8 +- .../test_recovery_bip39_dryrun.py | 8 +- .../reset_recovery/test_recovery_bip39_t2.py | 4 +- .../test_recovery_slip39_advanced.py | 12 +- .../test_recovery_slip39_advanced_dryrun.py | 4 +- .../test_recovery_slip39_basic.py | 24 +- .../test_recovery_slip39_basic_dryrun.py | 4 +- .../reset_recovery/test_reset_backup.py | 4 +- .../reset_recovery/test_reset_bip39_t2.py | 18 +- .../test_reset_recovery_bip39.py | 6 +- .../test_reset_recovery_slip39_advanced.py | 2 +- .../test_reset_recovery_slip39_basic.py | 4 +- .../test_reset_slip39_advanced.py | 4 +- .../reset_recovery/test_reset_slip39_basic.py | 4 +- tests/device_tests/ripple/test_get_address.py | 2 +- tests/device_tests/solana/test_sign_tx.py | 2 +- tests/device_tests/stellar/test_stellar.py | 2 +- tests/device_tests/test_autolock.py | 21 +- tests/device_tests/test_cancel.py | 4 +- tests/device_tests/test_language.py | 26 +- tests/device_tests/test_msg_applysettings.py | 163 ++--- tests/device_tests/test_msg_backup_device.py | 11 +- .../test_msg_change_wipe_code_t1.py | 16 +- .../test_msg_change_wipe_code_t2.py | 40 +- tests/device_tests/test_msg_changepin_t1.py | 32 +- tests/device_tests/test_msg_changepin_t2.py | 36 +- tests/device_tests/test_msg_ping.py | 9 +- tests/device_tests/test_msg_wipedevice.py | 28 +- tests/device_tests/test_pin.py | 14 +- tests/device_tests/test_protection_levels.py | 16 +- tests/device_tests/test_repeated_backup.py | 29 +- tests/device_tests/test_sdcard.py | 14 +- tests/device_tests/test_session.py | 26 +- .../test_session_id_and_passphrase.py | 18 +- tests/device_tests/tezos/test_getaddress.py | 2 +- tests/device_tests/tezos/test_sign_tx.py | 4 +- .../webauthn/test_msg_webauthn.py | 4 +- tests/device_tests/zcash/test_sign_tx.py | 32 +- tests/input_flows.py | 12 +- tests/persistence_tests/test_wipe_code.py | 16 +- tests/translations.py | 2 +- 89 files changed, 979 insertions(+), 1068 deletions(-) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index cd021433b2..2dbec291ea 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -788,10 +788,10 @@ class DebugUI: def __init__(self, debuglink: DebugLink) -> None: self.debuglink = debuglink + self.pins: t.Iterator[str] | None = None self.clear() def clear(self) -> None: - self.pins: t.Iterator[str] | None = None self.passphrase = None self.input_flow: t.Union[ t.Generator[None, messages.ButtonRequest, None], object, None @@ -947,7 +947,6 @@ class SessionDebugWrapper(Session): if isinstance(session, SessionDebugWrapper): raise Exception("Cannot wrap already wrapped session!") self.__dict__["_session"] = session - self.reset_debug_features() def __getattr__(self, name: str) -> t.Any: return getattr(self._session, name) @@ -962,61 +961,24 @@ class SessionDebugWrapper(Session): def protocol_version(self) -> int: return self.client.protocol_version + @property + def debug_client(self) -> TrezorClientDebugLink: + if not isinstance(self.client, TrezorClientDebugLink): + raise Exception("Debug client not available") + return self.client + def _write(self, msg: t.Any) -> None: - print("writing message:", msg.__class__.__name__) - self._session._write(self._filter_message(msg)) + self._session._write(self.debug_client._filter_message(msg)) def _read(self) -> t.Any: - resp = self._filter_message(self._session._read()) - print("reading message:", resp.__class__.__name__) - if self.actual_responses is not None: - self.actual_responses.append(resp) + resp = self.debug_client._filter_message(self._session._read()) + if self.debug_client.actual_responses is not None: + self.debug_client.actual_responses.append(resp) return resp def resume(self) -> None: self._session.resume() - def set_expected_responses( - self, - expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], - ) -> None: - """Set a sequence of expected responses to session calls. - - Within a given with-block, the list of received responses from device must - match the list of expected responses, otherwise an ``AssertionError`` is raised. - - If an expected response is given a field value other than ``None``, that field value - must exactly match the received field value. If a given field is ``None`` - (or unspecified) in the expected response, the received field value is not - checked. - - Each expected response can also be a tuple ``(bool, message)``. In that case, the - expected response is only evaluated if the first field is ``True``. - This is useful for differentiating sequences between Trezor models: - - >>> trezor_one = session.features.model == "1" - >>> session.set_expected_responses([ - >>> messages.ButtonRequest(code=ConfirmOutput), - >>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)), - >>> messages.Success(), - >>> ]) - """ - if not self.in_with_statement: - raise RuntimeError("Must be called inside 'with' statement") - - # make sure all items are (bool, message) tuples - expected_with_validity = ( - e if isinstance(e, tuple) else (True, e) for e in expected - ) - - # only apply those items that are (True, message) - self.expected_responses = [ - MessageFilter.from_message_or_type(expected) - for valid, expected in expected_with_validity - if valid - ] - self.actual_responses = [] - def lock(self) -> None: """Lock the device. @@ -1037,6 +999,214 @@ class SessionDebugWrapper(Session): btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH) self.refresh_features() + +class TrezorClientDebugLink(TrezorClient): + # This class implements automatic responses + # and other functionality for unit tests + # for various callbacks, created in order + # to automatically pass unit tests. + # + # This mixing should be used only for purposes + # of unit testing, because it will fail to work + # without special DebugLink interface provided + # by the device. + + def __init__( + self, + transport: Transport, + auto_interact: bool = True, + open_transport: bool = True, + debug_transport: Transport | None = None, + ) -> None: + try: + debug_transport = debug_transport or transport.find_debug() + self.debug = DebugLink(debug_transport, auto_interact) + if open_transport: + self.debug.open() + # try to open debuglink, see if it works + assert self.debug.transport.ping() + except Exception: + if not auto_interact: + self.debug = NullDebugLink() + else: + raise + + if open_transport: + transport.open() + + # set transport explicitly so that sync_responses can work + super().__init__(transport) + + self.transport = transport + self.ui: DebugUI = DebugUI(self.debug) + + def get_pin(_msg: messages.PinMatrixRequest) -> str: + try: + pin = self.ui.get_pin() + except Cancelled: + raise + return pin + + self.pin_callback = get_pin + self.button_callback = self.ui.button_request + + self.sync_responses() + + # So that we can choose right screenshotting logic (T1 vs TT) + # and know the supported debug capabilities + self.debug.model = self.model + self.debug.version = self.version + + self.reset_debug_features() + + @property + def layout_type(self) -> LayoutType: + return self.debug.layout_type + + def get_new_client(self) -> TrezorClientDebugLink: + new_client = TrezorClientDebugLink( + self.transport, + self.debug.allow_interactions, + open_transport=False, + debug_transport=self.debug.transport, + ) + new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir + new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory + new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter + return new_client + + def close_transport(self) -> None: + self.transport.close() + self.debug.close() + + def lock(self) -> None: + s = self.get_seedless_session() + s.lock() + + def get_session( + self, + passphrase: str | object = "", + derive_cardano: bool = False, + ) -> SessionDebugWrapper: + if isinstance(passphrase, str): + passphrase = Mnemonic.normalize_string(passphrase) + session = SessionDebugWrapper( + super().get_session( + passphrase, + derive_cardano, + ) + ) + return session + + # FIXME: can be deleted + def get_seedless_session( + self, *args: t.Any, **kwargs: t.Any + ) -> SessionDebugWrapper: + session = super().get_seedless_session(*args, **kwargs) + if not isinstance(session, SessionDebugWrapper): + session = SessionDebugWrapper(session) + return session + + def watch_layout(self, watch: bool = True) -> None: + """Enable or disable watching layout changes. + + Since trezor-core v2.3.2, it is necessary to call `watch_layout()` before + using `debug.wait_layout()`, otherwise layout changes are not reported. + """ + if self.version >= (2, 3, 2): + # version check is necessary because otherwise we cannot reliably detect + # whether and where to wait for reply: + # - T1 reports unknown debuglink messages on the wirelink + # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug + self.debug.watch_layout(watch) + + def use_pin_sequence(self, pins: t.Iterable[str]) -> None: + """Respond to PIN prompts from device with the provided PINs. + The sequence must be at least as long as the expected number of PIN prompts. + """ + self.ui.pins = iter(pins) + + def use_mnemonic(self, mnemonic: str) -> None: + """Use the provided mnemonic to respond to device. + Only applies to T1, where device prompts the host for mnemonic words.""" + self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") + + def sync_responses(self) -> None: + """Synchronize Trezor device receiving with caller. + + When a failed test does not read out the response, the next caller will write + a request, but read the previous response -- while the device had already sent + and placed into queue the new response. + + This function will call `Ping` and read responses until it locates a `Success` + with the expected text. This means that we are reading up-to-date responses. + """ + import secrets + + # Start by canceling whatever is on screen. This will work to cancel T1 PIN + # prompt, which is in TINY mode and does not respond to `Ping`. + if self.protocol_version is ProtocolVersion.V1: + assert isinstance(self.protocol, ProtocolV1Channel) + self.protocol.write(messages.Cancel()) + resp = self.protocol.read() + message = "SYNC" + secrets.token_hex(8) + self.protocol.write(messages.Ping(message=message)) + while resp != messages.Success(message=message): + try: + resp = self.protocol.read() + except Exception: + pass + + def mnemonic_callback(self, _) -> str: + word, pos = self.debug.read_recovery_word() + if word: + return word + if pos: + return self.mnemonic[pos - 1] + + raise RuntimeError("Unexpected call") + + def set_expected_responses( + self, + expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], + ) -> None: + """Set a sequence of expected responses to session calls. + + Within a given with-block, the list of received responses from device must + match the list of expected responses, otherwise an ``AssertionError`` is raised. + + If an expected response is given a field value other than ``None``, that field value + must exactly match the received field value. If a given field is ``None`` + (or unspecified) in the expected response, the received field value is not + checked. + + Each expected response can also be a tuple ``(bool, message)``. In that case, the + expected response is only evaluated if the first field is ``True``. + This is useful for differentiating sequences between Trezor models: + + >>> trezor_one = session.features.model == "1" + >>> client.set_expected_responses([ + >>> messages.ButtonRequest(code=ConfirmOutput), + >>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)), + >>> messages.Success(), + >>> ]) + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + # make sure all items are (bool, message) tuples + expected_with_validity = ( + e if isinstance(e, tuple) else (True, e) for e in expected + ) + + # only apply those items that are (True, message) + self.expected_responses = [ + MessageFilter.from_message_or_type(expected) + for valid, expected in expected_with_validity + if valid + ] + self.actual_responses = [] + def set_filter( self, message_type: t.Type[protobuf.MessageType], @@ -1069,6 +1239,7 @@ class SessionDebugWrapper(Session): Clears all debugging state that might have been modified by a testcase. """ + self.ui.clear() self.in_with_statement = False self.expected_responses: list[MessageFilter] | None = None self.actual_responses: list[protobuf.MessageType] | None = None @@ -1077,7 +1248,7 @@ class SessionDebugWrapper(Session): t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} - def __enter__(self) -> "SessionDebugWrapper": + def __enter__(self) -> "TrezorClientDebugLink": # For usage in with/expected_responses if self.in_with_statement: raise RuntimeError("Do not nest!") @@ -1092,10 +1263,8 @@ class SessionDebugWrapper(Session): actual_responses = self.actual_responses # grab a copy of the inputflow generator to raise an exception through it - if isinstance(self.client, TrezorClientDebugLink) and isinstance( - self.client.ui, DebugUI - ): - input_flow = self.client.ui.input_flow + if isinstance(self.ui, DebugUI): + input_flow = self.ui.input_flow else: input_flow = None @@ -1105,7 +1274,6 @@ class SessionDebugWrapper(Session): # If no other exception was raised, evaluate missed responses # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - elif isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. @@ -1165,205 +1333,9 @@ class SessionDebugWrapper(Session): output.append("") return output - -class TrezorClientDebugLink(TrezorClient): - # This class implements automatic responses - # and other functionality for unit tests - # for various callbacks, created in order - # to automatically pass unit tests. - # - # This mixing should be used only for purposes - # of unit testing, because it will fail to work - # without special DebugLink interface provided - # by the device. - - def __init__( - self, - transport: Transport, - auto_interact: bool = True, - open_transport: bool = True, - debug_transport: Transport | None = None, - ) -> None: - try: - debug_transport = debug_transport or transport.find_debug() - self.debug = DebugLink(debug_transport, auto_interact) - if open_transport: - self.debug.open() - # try to open debuglink, see if it works - assert self.debug.transport.ping() - except Exception: - if not auto_interact: - self.debug = NullDebugLink() - else: - raise - - if open_transport: - transport.open() - - # set transport explicitly so that sync_responses can work - super().__init__(transport) - - self.transport = transport - self.ui: DebugUI = DebugUI(self.debug) - - self.sync_responses() - - # So that we can choose right screenshotting logic (T1 vs TT) - # and know the supported debug capabilities - self.debug.model = self.model - self.debug.version = self.version - - @property - def layout_type(self) -> LayoutType: - return self.debug.layout_type - - def get_new_client(self) -> TrezorClientDebugLink: - new_client = TrezorClientDebugLink( - self.transport, - self.debug.allow_interactions, - open_transport=False, - debug_transport=self.debug.transport, - ) - new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir - new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory - new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter - return new_client - - def reset_debug_features(self) -> None: - """ - Prepare the debugging client for a new testcase. - - Clears all debugging state that might have been modified by a testcase. - """ - self.ui: DebugUI = DebugUI(self.debug) - self.in_with_statement = False - - def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - session._write(messages.ButtonAck()) - self.ui.button_request(msg) - return session._read() - - def pin_callback(self, session: Session, msg: messages.PinMatrixRequest) -> t.Any: - try: - pin = self.ui.get_pin(msg.type) - except Cancelled: - session.call_raw(messages.Cancel()) - raise - - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - session.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - resp = session.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise PinException(resp.code, resp.message) - else: - return resp - - def passphrase_callback( - self, session: Session, msg: messages.PassphraseRequest - ) -> t.Any: - available_on_device = ( - Capability.PassphraseEntry in session.features.capabilities - ) - - def send_passphrase( - passphrase: str | None = None, on_device: bool | None = None - ) -> MessageType: - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = session.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - if resp.state is not None: - session.id = resp.state - else: - raise RuntimeError("Object resp.state is None") - resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - if isinstance(session, SessionDebugWrapper): - passphrase = self.ui.get_passphrase( - available_on_device=available_on_device - ) - if passphrase is None: - passphrase = session.passphrase - else: - raise NotImplementedError - except Cancelled: - session.call_raw(messages.Cancel()) - raise - - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - session.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) - - # else process host-entered passphrase - if passphrase is None: - passphrase = "" - if not isinstance(passphrase, str): - raise RuntimeError(f"Passphrase must be a str {type(passphrase)}") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - session.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - def close_transport(self) -> None: - self.transport.close() - self.debug.close() - - def lock(self) -> None: - s = self.get_seedless_session() - s.lock() - - def get_session( - self, - passphrase: str | object | None = None, - derive_cardano: bool = False, - session_id: bytes | None = None, - ) -> SessionDebugWrapper: - if isinstance(passphrase, str): - passphrase = Mnemonic.normalize_string(passphrase) - session = SessionDebugWrapper( - super().get_session( - passphrase, derive_cardano, session_id, should_derive=False - ) - ) - session.passphrase = passphrase - return session - - def get_seedless_session( - self, *args: t.Any, **kwargs: t.Any - ) -> SessionDebugWrapper: - session = super().get_seedless_session(*args, **kwargs) - if not isinstance(session, SessionDebugWrapper): - session = SessionDebugWrapper(session) - return session - - def resume_session(self, session: Session) -> SessionDebugWrapper: - if isinstance(session, SessionDebugWrapper): - session._session = super().resume_session(session._session) - return session - else: - return SessionDebugWrapper(super().resume_session(session)) - def set_input_flow( - self, input_flow: InputFlowType | t.Callable[[], InputFlowType] + self, + input_flow: InputFlowType | t.Callable[[], InputFlowType], ) -> None: """Configure a sequence of input events for the current with-block. @@ -1387,7 +1359,7 @@ class TrezorClientDebugLink(TrezorClient): >>> >>> with client: >>> client.set_input_flow(input_flow) - >>> some_call(client) + >>> some_call(session) """ if not self.in_with_statement: raise RuntimeError("Must be called inside 'with' statement") @@ -1397,109 +1369,9 @@ class TrezorClientDebugLink(TrezorClient): if not hasattr(input_flow, "send"): raise RuntimeError("input_flow should be a generator function") self.ui.input_flow = input_flow + next(input_flow) # start the generator - def watch_layout(self, watch: bool = True) -> None: - """Enable or disable watching layout changes. - - Since trezor-core v2.3.2, it is necessary to call `watch_layout()` before - using `debug.wait_layout()`, otherwise layout changes are not reported. - """ - if self.version >= (2, 3, 2): - # version check is necessary because otherwise we cannot reliably detect - # whether and where to wait for reply: - # - T1 reports unknown debuglink messages on the wirelink - # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug - self.debug.watch_layout(watch) - - def __enter__(self) -> "TrezorClientDebugLink": - # For usage in with/expected_responses - if self.in_with_statement: - raise RuntimeError("Do not nest!") - self.in_with_statement = True - return self - - def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - # grab a copy of the inputflow generator to raise an exception through it - if isinstance(self.ui, DebugUI): - input_flow = self.ui.input_flow - else: - input_flow = None - - self.reset_debug_features() - - if exc_type is not None and isinstance(input_flow, t.Generator): - # Propagate the exception through the input flow, so that we see in - # traceback where it is stuck. - input_flow.throw(exc_type, value, traceback) - - def use_pin_sequence(self, pins: t.Iterable[str]) -> None: - """Respond to PIN prompts from device with the provided PINs. - The sequence must be at least as long as the expected number of PIN prompts. - """ - self.ui.pins = iter(pins) - - def use_mnemonic(self, mnemonic: str) -> None: - """Use the provided mnemonic to respond to device. - Only applies to T1, where device prompts the host for mnemonic words.""" - self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - - @staticmethod - def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: - start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) - stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) - output: list[str] = [] - output.append("Expected responses:") - if start_at > 0: - output.append(f" (...{start_at} previous responses omitted)") - for i in range(start_at, stop_at): - exp = expected[i] - prefix = " " if i != current else ">>> " - output.append(textwrap.indent(exp.to_string(), prefix)) - if stop_at < len(expected): - omitted = len(expected) - stop_at - output.append(f" (...{omitted} following responses omitted)") - - output.append("") - return output - - def sync_responses(self) -> None: - """Synchronize Trezor device receiving with caller. - - When a failed test does not read out the response, the next caller will write - a request, but read the previous response -- while the device had already sent - and placed into queue the new response. - - This function will call `Ping` and read responses until it locates a `Success` - with the expected text. This means that we are reading up-to-date responses. - """ - import secrets - - # Start by canceling whatever is on screen. This will work to cancel T1 PIN - # prompt, which is in TINY mode and does not respond to `Ping`. - if self.protocol_version is ProtocolVersion.V1: - assert isinstance(self.protocol, ProtocolV1Channel) - self.protocol.write(messages.Cancel()) - resp = self.protocol.read() - message = "SYNC" + secrets.token_hex(8) - self.protocol.write(messages.Ping(message=message)) - while resp != messages.Success(message=message): - try: - resp = self.protocol.read() - except Exception: - pass - - def mnemonic_callback(self, _) -> str: - word, pos = self.debug.read_recovery_word() - if word: - return word - if pos: - return self.mnemonic[pos - 1] - - raise RuntimeError("Unexpected call") - def load_device( session: "Session", diff --git a/tests/burn_tests/burntest_t2.py b/tests/burn_tests/burntest_t2.py index 5f1048254c..356734cf75 100755 --- a/tests/burn_tests/burntest_t2.py +++ b/tests/burn_tests/burntest_t2.py @@ -56,7 +56,7 @@ def pin_input_flow(client: Client, old_pin: str, new_pin: str): if __name__ == "__main__": wirelink = get_device() client = Client(wirelink) - client.open() + session = client.get_seedless_session() i = 0 @@ -83,3 +83,5 @@ if __name__ == "__main__": print(f"iteration {i}") i = i + 1 + + wirelink.close() diff --git a/tests/click_tests/test_autolock.py b/tests/click_tests/test_autolock.py index 152021a57c..c82c11fcc2 100644 --- a/tests/click_tests/test_autolock.py +++ b/tests/click_tests/test_autolock.py @@ -195,13 +195,15 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa assert TR.send__total_amount in layout.text_content() assert "0.0039 BTC" in layout.text_content() + client = session.client + def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - session.set_filter(messages.TxAck, None) + client.set_filter(messages.TxAck, None) return msg - with session, device_handler.client: - session.set_filter(messages.TxAck, sleepy_filter) + with client: + client.set_filter(messages.TxAck, sleepy_filter) # confirm transaction if debug.layout_type is LayoutType.Bolt: debug.click(debug.screen_buttons.ok(), hold_ms=1000) @@ -546,15 +548,17 @@ def test_autolock_does_not_interrupt_preauthorized( no_fee_indices=[], ) + client = session.client + def sleepy_filter(msg: MessageType) -> MessageType: time.sleep(10.1) - session.set_filter(messages.SignTx, None) + client.set_filter(messages.SignTx, None) return msg - with session: + with client: # Start DoPreauthorized flow when device is unlocked. Wait 10s before # delivering SignTx, by that time autolock timer should have fired. - session.set_filter(messages.SignTx, sleepy_filter) + client.set_filter(messages.SignTx, sleepy_filter) device_handler.run_with_provided_session( session, btc.sign_tx, diff --git a/tests/device_handler.py b/tests/device_handler.py index 0f69bff4d2..a3dc61314d 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -22,6 +22,10 @@ udp.SOCKET_TIMEOUT = 0.1 class NullUI: + @staticmethod + def clear(*args, **kwargs): + pass + @staticmethod def button_request(code): pass diff --git a/tests/device_tests/binance/test_get_address.py b/tests/device_tests/binance/test_get_address.py index 6b5a024767..824aed31f4 100644 --- a/tests/device_tests/binance/test_get_address.py +++ b/tests/device_tests/binance/test_get_address.py @@ -52,7 +52,7 @@ def test_binance_get_address_chunkify_details( # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 with session.client as client: - IF = InputFlowShowAddressQRCode(client) + IF = InputFlowShowAddressQRCode(session.client) client.set_input_flow(IF.get()) address = get_address( session, parse_path(path), show_display=True, chunkify=True diff --git a/tests/device_tests/binance/test_get_public_key.py b/tests/device_tests/binance/test_get_public_key.py index f65baa5dd8..6861449099 100644 --- a/tests/device_tests/binance/test_get_public_key.py +++ b/tests/device_tests/binance/test_get_public_key.py @@ -33,7 +33,7 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0") ) def test_binance_get_public_key(session: Session): with session.client as client: - IF = InputFlowShowXpubQRCode(client) + IF = InputFlowShowXpubQRCode(session.client) client.set_input_flow(IF.get()) sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) assert ( diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 65157487f4..bfa4ad3c5f 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -66,7 +66,7 @@ def test_sign_tx(session: Session, chunkify: bool): commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") with session.client as client: - client.use_pin_sequence([PIN]) + session.client.use_pin_sequence([PIN]) btc.authorize_coinjoin( session, coordinator="www.example.com", @@ -80,8 +80,8 @@ def test_sign_tx(session: Session, chunkify: bool): session.call(messages.LockDevice()) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( @@ -94,8 +94,8 @@ def test_sign_tx(session: Session, chunkify: bool): preauthorized=True, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [messages.PreauthorizedRequest, messages.OwnershipProof] ) btc.get_ownership_proof( @@ -207,8 +207,8 @@ def test_sign_tx(session: Session, chunkify: bool): no_fee_indices=[], ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ messages.PreauthorizedRequest(), request_input(0), @@ -452,8 +452,8 @@ def test_sign_tx_spend(session: Session): prev_txes=TX_CACHE_TESTNET, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, @@ -526,8 +526,8 @@ def test_sign_tx_migration(session: Session): prev_txes=TX_CACHE_TESTNET, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, @@ -666,8 +666,8 @@ def test_get_public_key(session: Session): ) # Get unlock path MAC. - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, @@ -689,8 +689,8 @@ def test_get_public_key(session: Session): ) # Ensure that user does not need to confirm access when path unlock is requested with MAC. - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ messages.UnlockedPathRequest, messages.PublicKey, @@ -720,8 +720,8 @@ def test_get_address(session: Session): ) # Unlock CoinJoin path. - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ messages.ButtonRequest(code=B.Other), messages.UnlockedPathRequest, diff --git a/tests/device_tests/bitcoin/test_bcash.py b/tests/device_tests/bitcoin/test_bcash.py index d1f0129741..ebe264a225 100644 --- a/tests/device_tests/bitcoin/test_bcash.py +++ b/tests/device_tests/bitcoin/test_bcash.py @@ -72,8 +72,8 @@ def test_send_bch_change(session: Session): amount=73_452, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -124,8 +124,8 @@ def test_send_bch_nochange(session: Session): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -182,8 +182,8 @@ def test_send_bch_oldaddr(session: Session): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -252,9 +252,9 @@ def test_attack_change_input(session: Session): return msg - with session: - session.set_filter(messages.TxAck, attack_processor) - session.set_expected_responses( + with session.client as client: + client.set_filter(messages.TxAck, attack_processor) + client.set_expected_responses( [ request_input(0), request_output(0), @@ -327,8 +327,8 @@ def test_send_bch_multisig_wrongchange(session: Session): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=23_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -395,8 +395,8 @@ def test_send_bch_multisig_change(session: Session): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=24_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -434,8 +434,8 @@ def test_send_bch_multisig_change(session: Session): ) out2.address_n[2] = H_(1) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -496,8 +496,8 @@ def test_send_bch_external_presigned(session: Session): amount=1_934_960, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), diff --git a/tests/device_tests/bitcoin/test_bgold.py b/tests/device_tests/bitcoin/test_bgold.py index 831ea216cb..a49e24fc70 100644 --- a/tests/device_tests/bitcoin/test_bgold.py +++ b/tests/device_tests/bitcoin/test_bgold.py @@ -71,8 +71,8 @@ def test_send_bitcoin_gold_change(session: Session): amount=1_252_382_934 - 1_896_050 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -124,8 +124,8 @@ def test_send_bitcoin_gold_nochange(session: Session): amount=1_252_382_934 + 38_448_607 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -193,9 +193,9 @@ def test_attack_change_input(session: Session): return msg - with session: - session.set_filter(messages.TxAck, attack_processor) - session.set_expected_responses( + with session.client as client: + client.set_filter(messages.TxAck, attack_processor) + client.set_expected_responses( [ request_input(0), request_output(0), @@ -254,8 +254,8 @@ def test_send_btg_multisig_change(session: Session): script_type=messages.OutputScriptType.PAYTOMULTISIG, amount=1_252_382_934 - 24_000 - 1_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -293,8 +293,8 @@ def test_send_btg_multisig_change(session: Session): ) out2.address_n[2] = H_(1) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -347,8 +347,8 @@ def test_send_p2sh(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -400,8 +400,8 @@ def test_send_p2sh_witness_change(session: Session): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=1_252_382_934 - 11_000 - 12_300_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -460,8 +460,8 @@ def test_send_multisig_1(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -484,7 +484,7 @@ def test_send_multisig_1(session: Session): inp1.multisig.signatures[0] = signatures[0] # sign with third key inp1.address_n[2] = H_(3) - session.set_expected_responses( + client.set_expected_responses( [ request_input(0), request_output(0), @@ -537,7 +537,7 @@ def test_send_mixed_inputs(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: + with session.client: _, serialized_tx = btc.sign_tx( session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API ) @@ -577,8 +577,8 @@ def test_send_btg_external_presigned(session: Session): amount=1_252_382_934 + 58_456 - 1_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), diff --git a/tests/device_tests/bitcoin/test_dash.py b/tests/device_tests/bitcoin/test_dash.py index 06b335c148..f0e83dac7a 100644 --- a/tests/device_tests/bitcoin/test_dash.py +++ b/tests/device_tests/bitcoin/test_dash.py @@ -57,8 +57,8 @@ def test_send_dash(session: Session): amount=999_999_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -106,8 +106,8 @@ def test_send_dash_dip2_input(session: Session): amount=95_000_000, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), diff --git a/tests/device_tests/bitcoin/test_decred.py b/tests/device_tests/bitcoin/test_decred.py index 204d055928..65cabec8a4 100644 --- a/tests/device_tests/bitcoin/test_decred.py +++ b/tests/device_tests/bitcoin/test_decred.py @@ -76,8 +76,8 @@ def test_send_decred(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -133,8 +133,8 @@ def test_purchase_ticket_decred(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -197,8 +197,8 @@ def test_spend_from_stake_generation_and_revocation_decred(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -278,8 +278,8 @@ def test_send_decred_change(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -384,8 +384,8 @@ def test_decred_multisig_change(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), diff --git a/tests/device_tests/bitcoin/test_descriptors.py b/tests/device_tests/bitcoin/test_descriptors.py index 7a077b2052..2da50b3e16 100644 --- a/tests/device_tests/bitcoin/test_descriptors.py +++ b/tests/device_tests/bitcoin/test_descriptors.py @@ -169,7 +169,7 @@ def test_descriptors( session: Session, coin, account, purpose, script_type, descriptors ): with session.client as client: - IF = InputFlowShowXpubQRCode(client) + IF = InputFlowShowXpubQRCode(session.client) client.set_input_flow(IF.get()) address_n = _address_n(purpose, coin, account, script_type) @@ -192,8 +192,8 @@ def test_descriptors_trezorlib( session: Session, coin, account, purpose, script_type, descriptors ): with session.client as client: - if client.model != models.T1B1: - IF = InputFlowShowXpubQRCode(client) + if session.client.model != models.T1B1: + IF = InputFlowShowXpubQRCode(session.client) client.set_input_flow(IF.get()) res = btc_cli._get_descriptor( session, coin, account, purpose, script_type, show_display=True diff --git a/tests/device_tests/bitcoin/test_getaddress.py b/tests/device_tests/bitcoin/test_getaddress.py index 3c8a2fbc9d..55288aaed3 100644 --- a/tests/device_tests/bitcoin/test_getaddress.py +++ b/tests/device_tests/bitcoin/test_getaddress.py @@ -272,7 +272,7 @@ def test_multisig(session: Session): for nr in range(1, 4): with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) assert ( btc.get_address( @@ -321,9 +321,9 @@ def test_multisig_missing(session: Session, show_display): ) for multisig in (multisig1, multisig2): - with session.client as client, pytest.raises(TrezorFailure): + with pytest.raises(TrezorFailure), session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) btc.get_address( session, @@ -347,7 +347,7 @@ def test_bch_multisig(session: Session): for nr in range(1, 4): with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) assert ( btc.get_address( @@ -396,8 +396,8 @@ def test_invalid_path(session: Session): def test_unknown_path(session: Session): UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0") - with session: - session.set_expected_responses([messages.Failure]) + with session.client as client: + client.set_expected_responses([messages.Failure]) with pytest.raises(TrezorFailure, match="Forbidden key path"): # account number is too high @@ -406,8 +406,8 @@ def test_unknown_path(session: Session): # disable safety checks device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) - with session, session.client as client: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ messages.ButtonRequest( code=messages.ButtonRequestType.UnknownDerivationPath @@ -417,14 +417,14 @@ def test_unknown_path(session: Session): ] ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) # try again with a warning btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) - with session: + with session.client as client: # no warning is displayed when the call is silent - session.set_expected_responses([messages.Address]) + client.set_expected_responses([messages.Address]) btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False) @@ -455,9 +455,9 @@ def test_multisig_different_paths(session: Session): with pytest.raises( Exception, match="Using different paths for different xpubs is not allowed" ): - with session.client as client, session: + with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) btc.get_address( session, @@ -471,7 +471,7 @@ def test_multisig_different_paths(session: Session): device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) btc.get_address( session, diff --git a/tests/device_tests/bitcoin/test_getaddress_segwit.py b/tests/device_tests/bitcoin/test_getaddress_segwit.py index b1e3affac7..5b23d53999 100644 --- a/tests/device_tests/bitcoin/test_getaddress_segwit.py +++ b/tests/device_tests/bitcoin/test_getaddress_segwit.py @@ -76,7 +76,7 @@ def test_show_segwit(session: Session): def test_show_segwit_altcoin(session: Session): with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) assert ( btc.get_address( diff --git a/tests/device_tests/bitcoin/test_getaddress_show.py b/tests/device_tests/bitcoin/test_getaddress_show.py index 464c9cc70e..86aacc4348 100644 --- a/tests/device_tests/bitcoin/test_getaddress_show.py +++ b/tests/device_tests/bitcoin/test_getaddress_show.py @@ -89,7 +89,7 @@ def test_show_tt( address: str, ): with session.client as client: - IF = InputFlowShowAddressQRCode(client) + IF = InputFlowShowAddressQRCode(session.client) client.set_input_flow(IF.get()) assert ( btc.get_address( @@ -110,7 +110,7 @@ def test_show_cancel( session: Session, path: str, script_type: messages.InputScriptType, address: str ): with session.client as client, pytest.raises(Cancelled): - IF = InputFlowShowAddressQRCodeCancel(client) + IF = InputFlowShowAddressQRCodeCancel(session.client) client.set_input_flow(IF.get()) btc.get_address( session, @@ -159,7 +159,7 @@ def test_show_multisig_3(session: Session): for i in [1, 2, 3]: with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) assert ( btc.get_address( @@ -273,11 +273,11 @@ def test_show_multisig_xpubs( ) for i in range(3): - with session, session.client as client: - IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) + with session.client as client: + IF = InputFlowShowMultisigXPUBs(session.client, address, xpubs, i) client.set_input_flow(IF.get()) - client.debug.synchronize_at("Homescreen") - client.watch_layout() + session.client.debug.synchronize_at("Homescreen") + session.client.watch_layout() btc.get_address( session, "Bitcoin", @@ -316,7 +316,7 @@ def test_show_multisig_15(session: Session): for i in range(15): with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) assert ( btc.get_address( diff --git a/tests/device_tests/bitcoin/test_getpublickey.py b/tests/device_tests/bitcoin/test_getpublickey.py index e013e6f71c..34e8f01b77 100644 --- a/tests/device_tests/bitcoin/test_getpublickey.py +++ b/tests/device_tests/bitcoin/test_getpublickey.py @@ -120,7 +120,7 @@ def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub): @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): with session.client as client: - IF = InputFlowShowXpubQRCode(client) + IF = InputFlowShowXpubQRCode(session.client) client.set_input_flow(IF.get()) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub @@ -158,7 +158,7 @@ def test_get_public_node_show_legacy( client.debug.press_yes() # finish the flow yield - with client: + with session.client as client: # test XPUB display flow (without showing QR code) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) assert res.xpub == xpub diff --git a/tests/device_tests/bitcoin/test_komodo.py b/tests/device_tests/bitcoin/test_komodo.py index 111acefc6f..b970239593 100644 --- a/tests/device_tests/bitcoin/test_komodo.py +++ b/tests/device_tests/bitcoin/test_komodo.py @@ -61,8 +61,8 @@ def test_one_one_fee_sapling(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -125,8 +125,8 @@ def test_one_one_rewards_claim(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), diff --git a/tests/device_tests/bitcoin/test_multisig.py b/tests/device_tests/bitcoin/test_multisig.py index 5888409d86..d8d680e1bc 100644 --- a/tests/device_tests/bitcoin/test_multisig.py +++ b/tests/device_tests/bitcoin/test_multisig.py @@ -101,8 +101,8 @@ def test_2_of_3(session: Session, chunkify: bool): request_finished(), ] - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) # Now we have first signature signatures1, _ = btc.sign_tx( @@ -143,8 +143,8 @@ def test_2_of_3(session: Session, chunkify: bool): multisig=multisig, ) - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( session, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET ) @@ -362,7 +362,7 @@ def test_15_of_15(session: Session): multisig=multisig, ) - with session: + with session.client: sig, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -424,6 +424,7 @@ def test_attack_change_input(session: Session): attacker to provide a 1-of-2 multisig change address. When `input_real` is provided in the signing phase, an error must occur. """ + client = session.client address_n = parse_path("m/48h/1h/0h/1h/0/0") # 2NErUdruXmM8o8bQySrzB3WdBRcmc5br4E8 attacker_multisig_public_key = bytes.fromhex( "03653a148b68584acb97947344a7d4fd6a6f8b8485cad12987ff8edac874268088" @@ -475,7 +476,7 @@ def test_attack_change_input(session: Session): ) # Transaction can be signed without the attack processor - with session.client as client: + with client: if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) @@ -497,8 +498,8 @@ def test_attack_change_input(session: Session): attack_count -= 1 return msg - with session: - session.set_filter(messages.TxAck, attack_processor) + with client: + client.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( session, diff --git a/tests/device_tests/bitcoin/test_multisig_change.py b/tests/device_tests/bitcoin/test_multisig_change.py index efc4f42d56..00a732a0dd 100644 --- a/tests/device_tests/bitcoin/test_multisig_change.py +++ b/tests/device_tests/bitcoin/test_multisig_change.py @@ -263,8 +263,8 @@ def test_external_external(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses(_responses(session, INP1, INP2)) + with session.client as client: + client.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( session, "Bitcoin", @@ -288,8 +288,8 @@ def test_external_internal(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session, session.client as client: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( _responses( session, INP1, @@ -299,7 +299,7 @@ def test_external_internal(session: Session): ) ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) btc.sign_tx( session, @@ -324,8 +324,8 @@ def test_internal_external(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session, session.client as client: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( _responses( session, INP1, @@ -335,7 +335,7 @@ def test_internal_external(session: Session): ) ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) btc.sign_tx( session, @@ -360,8 +360,8 @@ def test_multisig_external_external(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses(_responses(session, INP1, INP2)) + with session.client as client: + client.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( session, "Bitcoin", @@ -393,8 +393,8 @@ def test_multisig_change_match_first(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( _responses(session, INP1, INP2, change_indices=[1]) ) btc.sign_tx( @@ -428,8 +428,8 @@ def test_multisig_change_match_second(session: Session): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( _responses(session, INP1, INP2, change_indices=[2]) ) btc.sign_tx( @@ -464,8 +464,8 @@ def test_sorted_multisig_change_match_first(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( _responses(session, INP4, INP5, change_indices=[1]) ) btc.sign_tx( @@ -499,8 +499,8 @@ def test_multisig_mismatch_multisig_change(session: Session): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with session: - session.set_expected_responses(_responses(session, INP1, INP2)) + with session.client as client: + client.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( session, "Bitcoin", @@ -532,8 +532,8 @@ def test_sorted_multisig_mismatch_multisig_change(session: Session): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with session: - session.set_expected_responses(_responses(session, INP4, INP5)) + with session.client as client: + client.set_expected_responses(_responses(session, INP4, INP5)) btc.sign_tx( session, "Bitcoin", @@ -568,8 +568,8 @@ def test_multisig_mismatch_multisig_change_different_paths(session: Session): script_type=messages.OutputScriptType.PAYTOMULTISIG, ) - with session: - session.set_expected_responses(_responses(session, INP1, INP2)) + with session.client as client: + client.set_expected_responses(_responses(session, INP1, INP2)) btc.sign_tx( session, "Bitcoin", @@ -601,8 +601,8 @@ def test_multisig_mismatch_inputs(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses(_responses(session, INP1, INP3)) + with session.client as client: + client.set_expected_responses(_responses(session, INP1, INP3)) btc.sign_tx( session, "Bitcoin", @@ -635,8 +635,8 @@ def test_sorted_multisig_mismatch_inputs(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses(_responses(session, INP4, INP6)) + with session.client as client: + client.set_expected_responses(_responses(session, INP4, INP6)) btc.sign_tx( session, "Bitcoin", diff --git a/tests/device_tests/bitcoin/test_nonstandard_paths.py b/tests/device_tests/bitcoin/test_nonstandard_paths.py index 77d57aa951..8366f85055 100644 --- a/tests/device_tests/bitcoin/test_nonstandard_paths.py +++ b/tests/device_tests/bitcoin/test_nonstandard_paths.py @@ -115,7 +115,7 @@ def test_getaddress( for script_type in script_types: with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) res = btc.get_address( session, @@ -136,7 +136,7 @@ def test_signmessage( for script_type in script_types: with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) sig = btc.sign_message( @@ -177,7 +177,7 @@ def test_signtx( with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} @@ -204,7 +204,7 @@ def test_getaddress_multisig( with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) address = btc.get_address( session, @@ -263,7 +263,7 @@ def test_signtx_multisig(session: Session, paths: list[str], address_index: list with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) sig, _ = btc.sign_tx( session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} diff --git a/tests/device_tests/bitcoin/test_op_return.py b/tests/device_tests/bitcoin/test_op_return.py index 0aa8acb080..2a10bbc533 100644 --- a/tests/device_tests/bitcoin/test_op_return.py +++ b/tests/device_tests/bitcoin/test_op_return.py @@ -63,8 +63,8 @@ def test_opreturn(session: Session): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -110,8 +110,8 @@ def test_nonzero_opreturn(session: Session): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) @@ -136,8 +136,8 @@ def test_opreturn_address(session: Session): script_type=messages.OutputScriptType.PAYTOOPRETURN, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [request_input(0), request_output(0), messages.Failure()] ) with pytest.raises( diff --git a/tests/device_tests/bitcoin/test_signmessage.py b/tests/device_tests/bitcoin/test_signmessage.py index bf9ec4e326..dafa45ac52 100644 --- a/tests/device_tests/bitcoin/test_signmessage.py +++ b/tests/device_tests/bitcoin/test_signmessage.py @@ -328,7 +328,7 @@ def test_signmessage_long( signature: str, ): with session.client as client: - IF = InputFlowSignVerifyMessageLong(client) + IF = InputFlowSignVerifyMessageLong(session.client) client.set_input_flow(IF.get()) sig = btc.sign_message( session, @@ -357,7 +357,7 @@ def test_signmessage_info( signature: str, ): with session.client as client, pytest.raises(Cancelled): - IF = InputFlowSignMessageInfo(client) + IF = InputFlowSignMessageInfo(session.client) client.set_input_flow(IF.get()) sig = btc.sign_message( session, @@ -395,7 +395,7 @@ def test_signmessage_pagination(session: Session, message: str, is_long: bool): InputFlowSignVerifyMessageLong if is_long else InputFlowSignMessagePagination - )(client) + )(session.client) client.set_input_flow(IF.get()) btc.sign_message( session, @@ -417,8 +417,8 @@ def test_signmessage_pagination_trailing_newline(session: Session): message = "THIS\nMUST\nNOT\nBE\nPAGINATED\n" # The trailing newline must not cause a new paginated screen to appear. # The UI must be a single dialog without pagination. - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ # expect address confirmation message_filters.ButtonRequest(code=messages.ButtonRequestType.Other), @@ -438,8 +438,8 @@ def test_signmessage_pagination_trailing_newline(session: Session): def test_signmessage_path_warning(session: Session): message = "This is an example of a signed message." - with session, session.client as client: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ # expect a path warning message_filters.ButtonRequest( @@ -451,7 +451,7 @@ def test_signmessage_path_warning(session: Session): ] ) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) btc.sign_message( session, diff --git a/tests/device_tests/bitcoin/test_signtx.py b/tests/device_tests/bitcoin/test_signtx.py index d919d68792..77912bc9f9 100644 --- a/tests/device_tests/bitcoin/test_signtx.py +++ b/tests/device_tests/bitcoin/test_signtx.py @@ -125,8 +125,8 @@ def test_one_one_fee(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -178,8 +178,8 @@ def test_testnet_one_two_fee(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -228,8 +228,8 @@ def test_testnet_fee_high_warning(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -280,8 +280,8 @@ def test_one_two_fee(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -342,8 +342,8 @@ def test_one_three_fee(session: Session, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -413,8 +413,8 @@ def test_two_two(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -557,8 +557,8 @@ def test_lots_of_change(session: Session): request_change_outputs = [request_output(i + 1) for i in range(cnt)] - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -608,8 +608,8 @@ def test_fee_high_warning(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -665,7 +665,7 @@ def test_fee_high_hardfail(session: Session): session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily ) with session.client as client: - IF = InputFlowSignTxHighFee(client) + IF = InputFlowSignTxHighFee(session.client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( @@ -696,8 +696,8 @@ def test_not_enough_funds(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -726,8 +726,8 @@ def test_p2sh(session: Session): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -785,6 +785,7 @@ def test_testnet_big_amount(session: Session): def test_attack_change_outputs(session: Session): # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a + client = session.client inp1 = messages.TxInputType( address_n=parse_path("m/44h/0h/0h/0/55"), # 14nw9rFTWGUncHZjSqpPSJQaptWW7iRRB8 @@ -813,8 +814,8 @@ def test_attack_change_outputs(session: Session): ) # Test if the transaction can be signed normally - with session: - session.set_expected_responses( + with client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -869,11 +870,11 @@ def test_attack_change_outputs(session: Session): return msg - with session, pytest.raises( + with client, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - session.set_filter(messages.TxAck, attack_processor) + client.set_filter(messages.TxAck, attack_processor) btc.sign_tx( session, @@ -924,11 +925,11 @@ def test_attack_modify_change_address(session: Session): return msg - with session, pytest.raises( + with session.client as client, pytest.raises( TrezorFailure, match="Transaction has changed during signing" ): # Set up attack processors - session.set_filter(messages.TxAck, attack_processor) + client.set_filter(messages.TxAck, attack_processor) btc.sign_tx( session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET @@ -982,9 +983,9 @@ def test_attack_change_input_address(session: Session): return msg # Now run the attack, must trigger the exception - with session: - session.set_filter(messages.TxAck, attack_processor) - session.set_expected_responses( + with session.client as client: + client.set_filter(messages.TxAck, attack_processor) + client.set_expected_responses( [ request_input(0), request_output(0), @@ -1033,8 +1034,8 @@ def test_spend_coinbase(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -1091,8 +1092,8 @@ def test_two_changes(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -1150,8 +1151,8 @@ def test_change_on_main_chain_allowed(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -1418,8 +1419,8 @@ def test_lock_time(session: Session, lock_time: int, sequence: int): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -1468,7 +1469,7 @@ def test_lock_time_blockheight(session: Session): ) with session.client as client: - IF = InputFlowLockTimeBlockHeight(client, "499999999") + IF = InputFlowLockTimeBlockHeight(session.client, "499999999") client.set_input_flow(IF.get()) btc.sign_tx( @@ -1507,7 +1508,7 @@ def test_lock_time_datetime(session: Session, lock_time_str: str): lock_time_timestamp = int(lock_time_utc.timestamp()) with session.client as client: - IF = InputFlowLockTimeDatetime(client, lock_time_str) + IF = InputFlowLockTimeDatetime(session.client, lock_time_str) client.set_input_flow(IF.get()) btc.sign_tx( @@ -1539,7 +1540,7 @@ def test_information(session: Session): ) with session.client as client: - IF = InputFlowSignTxInformation(client) + IF = InputFlowSignTxInformation(session.client) client.set_input_flow(IF.get()) btc.sign_tx( @@ -1574,7 +1575,7 @@ def test_information_mixed(session: Session): ) with session.client as client: - IF = InputFlowSignTxInformationMixed(client) + IF = InputFlowSignTxInformationMixed(session.client) client.set_input_flow(IF.get()) btc.sign_tx( @@ -1605,7 +1606,7 @@ def test_information_cancel(session: Session): ) with session.client as client, pytest.raises(Cancelled): - IF = InputFlowSignTxInformationCancel(client) + IF = InputFlowSignTxInformationCancel(session.client) client.set_input_flow(IF.get()) btc.sign_tx( @@ -1653,7 +1654,7 @@ def test_information_replacement(session: Session): ) with session.client as client: - IF = InputFlowSignTxInformationReplacement(client) + IF = InputFlowSignTxInformationReplacement(session.client) client.set_input_flow(IF.get()) btc.sign_tx( diff --git a/tests/device_tests/bitcoin/test_signtx_amount_unit.py b/tests/device_tests/bitcoin/test_signtx_amount_unit.py index 50cc19151b..c889825a54 100644 --- a/tests/device_tests/bitcoin/test_signtx_amount_unit.py +++ b/tests/device_tests/bitcoin/test_signtx_amount_unit.py @@ -61,7 +61,7 @@ def test_signtx_testnet(session: Session, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with session: + with session.client: _, serialized_tx = btc.sign_tx( session, "Testnet", @@ -95,7 +95,7 @@ def test_signtx_btc(session: Session, amount_unit): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: + with session.client: _, serialized_tx = btc.sign_tx( session, "Bitcoin", diff --git a/tests/device_tests/bitcoin/test_signtx_external.py b/tests/device_tests/bitcoin/test_signtx_external.py index 4d44e3ec76..f3a80c43b3 100644 --- a/tests/device_tests/bitcoin/test_signtx_external.py +++ b/tests/device_tests/bitcoin/test_signtx_external.py @@ -142,7 +142,7 @@ def test_p2pkh_presigned(session: Session): ) # Test with first input as pre-signed external. - with session: + with session.client: _, serialized_tx = btc.sign_tx( session, "Testnet", @@ -155,7 +155,7 @@ def test_p2pkh_presigned(session: Session): assert serialized_tx.hex() == expected_tx # Test with second input as pre-signed external. - with session: + with session.client: _, serialized_tx = btc.sign_tx( session, "Testnet", @@ -216,8 +216,8 @@ def test_p2wpkh_in_p2sh_presigned(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -267,8 +267,8 @@ def test_p2wpkh_in_p2sh_presigned(session: Session): # Test corrupted script hash in scriptsig. inp1.script_sig[10] ^= 1 - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -339,7 +339,7 @@ def test_p2wpkh_presigned(session: Session): ) # Test with second input as pre-signed external. - with session: + with session.client: _, serialized_tx = btc.sign_tx( session, "Testnet", @@ -399,8 +399,8 @@ def test_p2wsh_external_presigned(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -444,8 +444,8 @@ def test_p2wsh_external_presigned(session: Session): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -509,8 +509,8 @@ def test_p2tr_external_presigned(session: Session): amount=4_600, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -541,8 +541,8 @@ def test_p2tr_external_presigned(session: Session): # Test corrupted signature in witness. inp2.witness[10] ^= 1 - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -610,9 +610,9 @@ def test_p2wpkh_with_proof(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: + with session.client as client: is_t1 = session.model is models.T1B1 - session.set_expected_responses( + client.set_expected_responses( [ request_input(0), request_input(1), @@ -703,9 +703,9 @@ def test_p2tr_with_proof(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: + with session.client as client: is_t1 = session.model is models.T1B1 - session.set_expected_responses( + client.set_expected_responses( [ request_input(0), request_input(1), @@ -770,8 +770,8 @@ def test_p2wpkh_with_false_proof(session: Session): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), diff --git a/tests/device_tests/bitcoin/test_signtx_invalid_path.py b/tests/device_tests/bitcoin/test_signtx_invalid_path.py index 27f0599de9..6277633535 100644 --- a/tests/device_tests/bitcoin/test_signtx_invalid_path.py +++ b/tests/device_tests/bitcoin/test_signtx_invalid_path.py @@ -82,7 +82,7 @@ def test_invalid_path_prompt(session: Session): with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) @@ -108,7 +108,7 @@ def test_invalid_path_pass_forkid(session: Session): with session.client as client: if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) @@ -178,8 +178,8 @@ def test_attack_path_segwit(session: Session): return msg - with session: - session.set_filter(messages.TxAck, attack_processor) + with session.client as client: + client.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( session, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} @@ -202,8 +202,8 @@ def test_invalid_path_fail_asap(session: Session): script_type=messages.OutputScriptType.PAYTOWITNESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), messages.Failure(code=messages.FailureType.DataError), diff --git a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py index d3ab1cf37b..09160b7457 100644 --- a/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py +++ b/tests/device_tests/bitcoin/test_signtx_mixed_inputs.py @@ -58,7 +58,7 @@ def test_non_segwit_segwit_inputs(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: + with session.client: signatures, serialized_tx = btc.sign_tx( session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) @@ -94,7 +94,7 @@ def test_segwit_non_segwit_inputs(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: + with session.client: signatures, serialized_tx = btc.sign_tx( session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API ) @@ -138,7 +138,7 @@ def test_segwit_non_segwit_segwit_inputs(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: + with session.client: signatures, serialized_tx = btc.sign_tx( session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) @@ -180,7 +180,7 @@ def test_non_segwit_segwit_non_segwit_inputs(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: + with session.client: signatures, serialized_tx = btc.sign_tx( session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API ) diff --git a/tests/device_tests/bitcoin/test_signtx_payreq.py b/tests/device_tests/bitcoin/test_signtx_payreq.py index 32c90d05e0..1d1ead8ff4 100644 --- a/tests/device_tests/bitcoin/test_signtx_payreq.py +++ b/tests/device_tests/bitcoin/test_signtx_payreq.py @@ -204,7 +204,7 @@ def test_payment_request_details(session: Session): ] with session.client as client: - IF = InputFlowPaymentRequestDetails(client, outputs) + IF = InputFlowPaymentRequestDetails(session.client, outputs) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( diff --git a/tests/device_tests/bitcoin/test_signtx_prevhash.py b/tests/device_tests/bitcoin/test_signtx_prevhash.py index a2f96c04ed..61bcf1a499 100644 --- a/tests/device_tests/bitcoin/test_signtx_prevhash.py +++ b/tests/device_tests/bitcoin/test_signtx_prevhash.py @@ -130,8 +130,8 @@ def test_invalid_prev_hash_attack(session: Session, prev_hash): msg.tx.inputs[0].prev_hash = prev_hash return msg - with session, session.client as client, pytest.raises(TrezorFailure) as e: - session.set_filter(messages.TxAck, attack_filter) + with session.client as client, pytest.raises(TrezorFailure) as e: + client.set_filter(messages.TxAck, attack_filter) if is_core(session): IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) @@ -168,9 +168,9 @@ def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash): tx_hash = hash_tx(serialize_tx(prev_tx)) inp0.prev_hash = tx_hash - with session, session.client as client, pytest.raises(TrezorFailure) as e: + with session.client as client, pytest.raises(TrezorFailure) as e: if session.model is not models.T1B1: - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) _check_error_message(prev_hash, session.model, e.value.message) diff --git a/tests/device_tests/bitcoin/test_signtx_replacement.py b/tests/device_tests/bitcoin/test_signtx_replacement.py index fd5db6a502..590ed7e2e2 100644 --- a/tests/device_tests/bitcoin/test_signtx_replacement.py +++ b/tests/device_tests/bitcoin/test_signtx_replacement.py @@ -116,8 +116,8 @@ def test_p2pkh_fee_bump(session: Session): orig_index=1, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_meta(TXHASH_50f6f1), @@ -190,7 +190,7 @@ def test_p2wpkh_op_return_fee_bump(session: Session): orig_index=1, ) - with session: + with session.client: _, serialized_tx = btc.sign_tx( session, "Testnet", @@ -243,8 +243,8 @@ def test_p2tr_fee_bump(session: Session): orig_index=1, script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_meta(TXHASH_8e4af7), @@ -312,8 +312,8 @@ def test_p2wpkh_finalize(session: Session): orig_index=1, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_meta(TXHASH_70f987), @@ -444,8 +444,8 @@ def test_p2wpkh_payjoin( orig_index=1, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_meta(TXHASH_65b768), @@ -520,8 +520,8 @@ def test_p2wpkh_in_p2sh_remove_change(session: Session): orig_index=0, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -599,8 +599,8 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(session: Session): orig_index=0, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), @@ -720,8 +720,8 @@ def test_tx_meld(session: Session): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_meta(TXHASH_334cd7), diff --git a/tests/device_tests/bitcoin/test_signtx_segwit.py b/tests/device_tests/bitcoin/test_signtx_segwit.py index ef8c988ff3..643b37082d 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit.py @@ -66,8 +66,8 @@ def test_send_p2sh(session: Session, chunkify: bool): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -124,8 +124,8 @@ def test_send_p2sh_change(session: Session): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -179,8 +179,8 @@ def test_testnet_segwit_big_amount(session: Session): amount=2**32 + 1, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -254,8 +254,8 @@ def test_send_multisig_1(session: Session): request_finished(), ] - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -265,8 +265,8 @@ def test_send_multisig_1(session: Session): # sign with third key inp1.address_n[2] = H_(3) - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -282,6 +282,7 @@ def test_attack_change_input_address(session: Session): # Simulates an attack where the user is coerced into unknowingly # transferring funds from one account to another one of their accounts, # potentially resulting in privacy issues. + client = session.client inp1 = messages.TxInputType( address_n=parse_path("m/49h/1h/0h/1/0"), @@ -303,8 +304,8 @@ def test_attack_change_input_address(session: Session): ) # Test if the transaction can be signed normally. - with session: - session.set_expected_responses( + with client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -349,8 +350,8 @@ def test_attack_change_input_address(session: Session): return msg # Now run the attack, must trigger the exception - with session: - session.set_filter(messages.TxAck, attack_processor) + with client: + client.set_filter(messages.TxAck, attack_processor) with pytest.raises(TrezorFailure): btc.sign_tx( session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET @@ -360,6 +361,7 @@ def test_attack_change_input_address(session: Session): def test_attack_mixed_inputs(session: Session): TRUE_AMOUNT = 123_456_789 FAKE_AMOUNT = 120_000_000 + client = session.client inp1 = messages.TxInputType( address_n=parse_path("m/44h/1h/0h/0/0"), @@ -421,10 +423,10 @@ def test_attack_mixed_inputs(session: Session): # T1 asks for first input for witness again expected_responses.insert(-2, request_input(0)) - with session: + with client: # Sign unmodified transaction. # "Fee over threshold" warning is displayed - fee is the whole TRUE_AMOUNT - session.set_expected_responses(expected_responses) + client.set_expected_responses(expected_responses) btc.sign_tx( session, "Testnet", @@ -446,8 +448,8 @@ def test_attack_mixed_inputs(session: Session): expected_responses[:4] + expected_responses[5:16] + [messages.Failure()] ) - with pytest.raises(TrezorFailure) as e, session: - session.set_expected_responses(expected_responses) + with pytest.raises(TrezorFailure) as e, client: + client.set_expected_responses(expected_responses) btc.sign_tx( session, "Testnet", diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index 920b0bf48b..96202345ad 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -82,8 +82,8 @@ def test_send_p2sh(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -137,8 +137,8 @@ def test_send_p2sh_change(session: Session): script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, amount=123_456_789 - 11_000 - 12_300_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -190,8 +190,8 @@ def test_send_native(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=100_000 - 40_000 - 10_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -244,7 +244,7 @@ def test_send_to_taproot(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, amount=10_000 - 7_000 - 200, ) - with session: + with session.client: _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET ) @@ -277,8 +277,8 @@ def test_send_native_change(session: Session): script_type=messages.OutputScriptType.PAYTOWITNESS, amount=100_000 - 40_000 - 10_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -344,8 +344,8 @@ def test_send_both(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -449,8 +449,8 @@ def test_send_multisig_1(session: Session): request_finished(), ] - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -460,8 +460,8 @@ def test_send_multisig_1(session: Session): # sign with third key inp1.address_n[2] = H_(3) - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -526,8 +526,8 @@ def test_send_multisig_2(session: Session): request_finished(), ] - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) signatures, _ = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -537,8 +537,8 @@ def test_send_multisig_2(session: Session): # sign with first key inp1.address_n[2] = H_(1) - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET ) @@ -611,10 +611,10 @@ def test_send_multisig_3_change(session: Session): request_finished(), ] - with session, session.client as client: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET @@ -626,10 +626,10 @@ def test_send_multisig_3_change(session: Session): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with session, session.client as client: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET @@ -703,10 +703,10 @@ def test_send_multisig_4_change(session: Session): request_finished(), ] - with session, session.client as client: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) signatures, _ = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET @@ -718,10 +718,10 @@ def test_send_multisig_4_change(session: Session): inp1.address_n[2] = H_(3) out1.address_n[2] = H_(3) - with session, session.client as client: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) if is_core(session): - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) _, serialized_tx = btc.sign_tx( session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET @@ -788,8 +788,8 @@ def test_multisig_mismatch_inputs_single(session: Session): amount=100_000 + 100_000 - 50_000 - 10_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), diff --git a/tests/device_tests/bitcoin/test_signtx_taproot.py b/tests/device_tests/bitcoin/test_signtx_taproot.py index 0453474af9..bc20700100 100644 --- a/tests/device_tests/bitcoin/test_signtx_taproot.py +++ b/tests/device_tests/bitcoin/test_signtx_taproot.py @@ -79,8 +79,8 @@ def test_send_p2tr(session: Session, chunkify: bool): amount=4_450, script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -133,8 +133,8 @@ def test_send_two_with_change(session: Session): script_type=messages.OutputScriptType.PAYTOTAPROOT, amount=6_800 + 13_000 - 200 - 15_000, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -222,8 +222,8 @@ def test_send_mixed(session: Session): script_type=messages.OutputScriptType.PAYTOTAPROOT, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ # process inputs request_input(0), @@ -353,9 +353,9 @@ def test_attack_script_type(session: Session): return msg - with session: - session.set_filter(messages.TxAck, attack_processor) - session.set_expected_responses( + with session.client as client: + client.set_filter(messages.TxAck, attack_processor) + client.set_expected_responses( [ request_input(0), request_input(1), @@ -406,8 +406,8 @@ def test_send_invalid_address(session: Session, address: str): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session, pytest.raises(TrezorFailure): - session.set_expected_responses( + with session.client as client, pytest.raises(TrezorFailure): + client.set_expected_responses( [ request_input(0), request_output(0), diff --git a/tests/device_tests/bitcoin/test_verifymessage.py b/tests/device_tests/bitcoin/test_verifymessage.py index 36b7cc31f0..f442a94016 100644 --- a/tests/device_tests/bitcoin/test_verifymessage.py +++ b/tests/device_tests/bitcoin/test_verifymessage.py @@ -41,7 +41,7 @@ def test_message_long_legacy(session: Session): @pytest.mark.models("core") def test_message_long_core(session: Session): with session.client as client: - IF = InputFlowSignVerifyMessageLong(client, verify=True) + IF = InputFlowSignVerifyMessageLong(session.client, verify=True) client.set_input_flow(IF.get()) ret = btc.verify_message( session, diff --git a/tests/device_tests/bitcoin/test_zcash.py b/tests/device_tests/bitcoin/test_zcash.py index adb9958915..b798bc8b13 100644 --- a/tests/device_tests/bitcoin/test_zcash.py +++ b/tests/device_tests/bitcoin/test_zcash.py @@ -75,7 +75,7 @@ def test_v3_not_supported(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session, pytest.raises(TrezorFailure, match="DataError"): + with session.client, pytest.raises(TrezorFailure, match="DataError"): btc.sign_tx( session, "Zcash Testnet", @@ -106,8 +106,8 @@ def test_one_one_fee_sapling(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -210,7 +210,7 @@ def test_spend_old_versions(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: + with session.client: _, serialized_tx = btc.sign_tx( session, "Zcash Testnet", @@ -259,8 +259,8 @@ def test_external_presigned(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), diff --git a/tests/device_tests/cardano/test_address_public_key.py b/tests/device_tests/cardano/test_address_public_key.py index bdc68bd065..30ea092548 100644 --- a/tests/device_tests/cardano/test_address_public_key.py +++ b/tests/device_tests/cardano/test_address_public_key.py @@ -95,9 +95,9 @@ def test_cardano_get_address(session: Session, chunkify: bool, parameters, resul "cardano/get_public_key.derivations.json", ) def test_cardano_get_public_key(session: Session, parameters, result): - with session: + with session.client as client: IF = InputFlowShowXpubQRCode(session.client, passphrase_request_expected=False) - session.set_input_flow(IF.get()) + client.set_input_flow(IF.get()) # session.init_device(new_session=True, derive_cardano=True) derivation_type = CardanoDerivationType.__members__[ diff --git a/tests/device_tests/cardano/test_sign_tx.py b/tests/device_tests/cardano/test_sign_tx.py index 362a1793ce..fc5d6d0e98 100644 --- a/tests/device_tests/cardano/test_sign_tx.py +++ b/tests/device_tests/cardano/test_sign_tx.py @@ -63,7 +63,7 @@ def test_cardano_sign_tx(session: Session, parameters, result): response = call_sign_tx( session, parameters, - input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), + input_flow=lambda client: InputFlowConfirmAllWarnings(session.client).get(), ) assert response == _transform_expected_result(result) @@ -124,8 +124,8 @@ def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool = with session.client as client: if input_flow is not None: - client.watch_layout() - client.set_input_flow(input_flow(client)) + session.client.watch_layout() + client.set_input_flow(input_flow(session.client)) return cardano.sign_tx( session=session, diff --git a/tests/device_tests/eos/test_get_public_key.py b/tests/device_tests/eos/test_get_public_key.py index d99c54cb2b..7d0a0de156 100644 --- a/tests/device_tests/eos/test_get_public_key.py +++ b/tests/device_tests/eos/test_get_public_key.py @@ -30,7 +30,7 @@ from ...input_flows import InputFlowShowXpubQRCode @pytest.mark.setup_client(mnemonic=MNEMONIC12) def test_eos_get_public_key(session: Session): with session.client as client: - IF = InputFlowShowXpubQRCode(client) + IF = InputFlowShowXpubQRCode(session.client) client.set_input_flow(IF.get()) public_key = get_public_key( session, parse_path("m/44h/194h/0h/0/0"), show_display=True diff --git a/tests/device_tests/eos/test_signtx.py b/tests/device_tests/eos/test_signtx.py index 54ebece6a9..d5eac0347d 100644 --- a/tests/device_tests/eos/test_signtx.py +++ b/tests/device_tests/eos/test_signtx.py @@ -60,7 +60,7 @@ def test_eos_signtx_transfer_token(session: Session, chunkify: bool): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) assert isinstance(resp, EosSignedTx) assert ( @@ -93,7 +93,7 @@ def test_eos_signtx_buyram(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -126,7 +126,7 @@ def test_eos_signtx_buyrambytes(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -155,7 +155,7 @@ def test_eos_signtx_sellram(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -190,7 +190,7 @@ def test_eos_signtx_delegate(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -224,7 +224,7 @@ def test_eos_signtx_undelegate(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -253,7 +253,7 @@ def test_eos_signtx_refund(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -287,7 +287,7 @@ def test_eos_signtx_linkauth(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -320,7 +320,7 @@ def test_eos_signtx_unlinkauth(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -376,7 +376,7 @@ def test_eos_signtx_updateauth(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -405,7 +405,7 @@ def test_eos_signtx_deleteauth(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -468,7 +468,7 @@ def test_eos_signtx_vote(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -497,7 +497,7 @@ def test_eos_signtx_vote_proxy(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -526,7 +526,7 @@ def test_eos_signtx_unknown(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -602,7 +602,7 @@ def test_eos_signtx_newaccount(session: Session): "transaction_extensions": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( @@ -638,7 +638,7 @@ def test_eos_signtx_setcontract(session: Session): "context_free_data": [], } - with session: + with session.client: resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) assert isinstance(resp, EosSignedTx) assert ( diff --git a/tests/device_tests/ethereum/test_definitions.py b/tests/device_tests/ethereum/test_definitions.py index 9cc3fd5704..9df0ed96af 100644 --- a/tests/device_tests/ethereum/test_definitions.py +++ b/tests/device_tests/ethereum/test_definitions.py @@ -123,9 +123,9 @@ def test_external_token(session: Session) -> None: def test_external_chain_without_token(session: Session) -> None: - with session, session.client as client: - if not client.debug.legacy_debug: - client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) + with session.client as client: + if not session.client.debug.legacy_debug: + client.set_input_flow(InputFlowConfirmAllWarnings(session.client).get()) # when using an external chains, unknown tokens are allowed network = common.encode_network(chain_id=66666, slip44=60) params = DEFAULT_ERC20_PARAMS.copy() @@ -145,9 +145,9 @@ def test_external_chain_token_ok(session: Session) -> None: def test_external_chain_token_mismatch(session: Session) -> None: - with session, session.client as client: - if not client.debug.legacy_debug: - client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) + with session.client as client: + if not session.client.debug.legacy_debug: + client.set_input_flow(InputFlowConfirmAllWarnings(session.client).get()) # when providing external defs, we explicitly allow, but not use, tokens # from other chains network = common.encode_network(chain_id=66666, slip44=60) diff --git a/tests/device_tests/ethereum/test_getaddress.py b/tests/device_tests/ethereum/test_getaddress.py index b57fcd6afd..c35049c354 100644 --- a/tests/device_tests/ethereum/test_getaddress.py +++ b/tests/device_tests/ethereum/test_getaddress.py @@ -38,7 +38,7 @@ def test_getaddress(session: Session, parameters, result): @parametrize_using_common_fixtures("ethereum/getaddress.json") def test_getaddress_chunkify_details(session: Session, parameters, result): with session.client as client: - IF = InputFlowShowAddressQRCode(client) + IF = InputFlowShowAddressQRCode(session.client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) assert ( diff --git a/tests/device_tests/ethereum/test_sign_typed_data.py b/tests/device_tests/ethereum/test_sign_typed_data.py index fdfde6df7e..951c5124a5 100644 --- a/tests/device_tests/ethereum/test_sign_typed_data.py +++ b/tests/device_tests/ethereum/test_sign_typed_data.py @@ -29,7 +29,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum] @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") def test_ethereum_sign_typed_data(session: Session, parameters, result): - with session: + with session.client: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data( session, @@ -44,7 +44,7 @@ def test_ethereum_sign_typed_data(session: Session, parameters, result): @pytest.mark.models("legacy") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json") def test_ethereum_sign_typed_data_blind(session: Session, parameters, result): - with session: + with session.client: address_n = parse_path(parameters["path"]) ret = ethereum.sign_typed_data_hash( session, @@ -112,8 +112,8 @@ def test_ethereum_sign_typed_data_show_more_button(session: Session): @pytest.mark.models("core") def test_ethereum_sign_typed_data_cancel(session: Session): with session.client as client, pytest.raises(exceptions.Cancelled): - client.watch_layout() - IF = InputFlowEIP712Cancel(client) + session.client.watch_layout() + IF = InputFlowEIP712Cancel(session.client) client.set_input_flow(IF.get()) ethereum.sign_typed_data( session, diff --git a/tests/device_tests/ethereum/test_sign_verify_message.py b/tests/device_tests/ethereum/test_sign_verify_message.py index c3ef56984c..3d9052e3b6 100644 --- a/tests/device_tests/ethereum/test_sign_verify_message.py +++ b/tests/device_tests/ethereum/test_sign_verify_message.py @@ -37,7 +37,7 @@ def test_signmessage(session: Session, parameters, result): assert res.signature.hex() == result["sig"] else: with session.client as client: - IF = InputFlowSignVerifyMessageLong(client) + IF = InputFlowSignVerifyMessageLong(session.client) client.set_input_flow(IF.get()) res = ethereum.sign_message( session, parse_path(parameters["path"]), parameters["msg"] @@ -58,7 +58,7 @@ def test_verify(session: Session, parameters, result): assert res is True else: with session.client as client: - IF = InputFlowSignVerifyMessageLong(client, verify=True) + IF = InputFlowSignVerifyMessageLong(session.client, verify=True) client.set_input_flow(IF.get()) res = ethereum.verify_message( session, diff --git a/tests/device_tests/ethereum/test_signtx.py b/tests/device_tests/ethereum/test_signtx.py index bdbbf39671..3c69788942 100644 --- a/tests/device_tests/ethereum/test_signtx.py +++ b/tests/device_tests/ethereum/test_signtx.py @@ -147,9 +147,9 @@ def test_signtx_go_back_from_summary(session: Session): def test_signtx_eip1559( session: Session, chunkify: bool, parameters: dict, result: dict ): - with session, session.client as client: - if not client.debug.legacy_debug: - client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) + with session.client as client: + if not session.client.debug.legacy_debug: + client.set_input_flow(InputFlowConfirmAllWarnings(session.client).get()) sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( session, n=parse_path(parameters["path"]), @@ -218,8 +218,8 @@ def test_data_streaming(session: Session): """Only verifying the expected responses, the signatures are checked in vectorized function above. """ - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), @@ -266,7 +266,7 @@ def test_data_streaming(session: Session): def test_signtx_eip1559_access_list(session: Session): - with session: + with session.client: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( session, @@ -305,7 +305,7 @@ def test_signtx_eip1559_access_list(session: Session): def test_signtx_eip1559_access_list_larger(session: Session): - with session: + with session.client: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( session, @@ -438,6 +438,8 @@ HEXDATA = "0123456789abcd000023456789abcd010003456789abcd020000456789abcd0300000 ) @pytest.mark.models("core") def test_signtx_data_pagination(session: Session, flow): + client = session.client + def _sign_tx_call(): ethereum.sign_tx( session, @@ -452,15 +454,15 @@ def test_signtx_data_pagination(session: Session, flow): data=bytes.fromhex(HEXDATA), ) - with session, session.client as client: + with client: client.watch_layout() - client.set_input_flow(flow(client)) + client.set_input_flow(flow(session.client)) _sign_tx_call() if flow is not input_flow_data_scroll_down: - with session, session.client as client, pytest.raises(exceptions.Cancelled): + with client, pytest.raises(exceptions.Cancelled): client.watch_layout() - client.set_input_flow(flow(client, cancel=True)) + client.set_input_flow(flow(session.client, cancel=True)) _sign_tx_call() @@ -500,7 +502,7 @@ def test_signtx_staking_bad_inputs(session: Session, parameters: dict, result: d @pytest.mark.models("core") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_eip1559.json") def test_signtx_staking_eip1559(session: Session, parameters: dict, result: dict): - with session: + with session.client: sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( session, n=parse_path(parameters["path"]), diff --git a/tests/device_tests/misc/test_msg_enablelabeling.py b/tests/device_tests/misc/test_msg_enablelabeling.py index e1c0300191..07ba5b1c79 100644 --- a/tests/device_tests/misc/test_msg_enablelabeling.py +++ b/tests/device_tests/misc/test_msg_enablelabeling.py @@ -33,7 +33,7 @@ def test_encrypt(client: Client): client.debug.press_yes() session = client.get_session() - with client, session: + with session.client as client: client.set_input_flow(input_flow()) misc.encrypt_keyvalue( session, diff --git a/tests/device_tests/misc/test_msg_getentropy.py b/tests/device_tests/misc/test_msg_getentropy.py index d5d19425f9..64c3abeb3e 100644 --- a/tests/device_tests/misc/test_msg_getentropy.py +++ b/tests/device_tests/misc/test_msg_getentropy.py @@ -41,8 +41,8 @@ def entropy(data): @pytest.mark.parametrize("entropy_length", ENTROPY_LENGTHS) def test_entropy(session: Session, entropy_length): - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [m.ButtonRequest(code=m.ButtonRequestType.ProtectCall), m.Entropy] ) ent = misc.get_entropy(session, entropy_length) diff --git a/tests/device_tests/monero/test_getaddress.py b/tests/device_tests/monero/test_getaddress.py index 1a6d3ffc01..9b363ed8ec 100644 --- a/tests/device_tests/monero/test_getaddress.py +++ b/tests/device_tests/monero/test_getaddress.py @@ -57,7 +57,7 @@ def test_monero_getaddress_chunkify_details( session: Session, path: str, expected_address: bytes ): with session.client as client: - IF = InputFlowShowAddressQRCode(client) + IF = InputFlowShowAddressQRCode(session.client) client.set_input_flow(IF.get()) address = monero.get_address( session, parse_path(path), show_display=True, chunkify=True diff --git a/tests/device_tests/nem/test_signtx_others.py b/tests/device_tests/nem/test_signtx_others.py index 9760d8c523..61215cd256 100644 --- a/tests/device_tests/nem/test_signtx_others.py +++ b/tests/device_tests/nem/test_signtx_others.py @@ -32,7 +32,7 @@ pytestmark = [ # assertion data from T1 def test_nem_signtx_importance_transfer(session: Session): - with session: + with session.client: tx = nem.sign_tx( session, parse_path("m/44h/1h/0h/0h/0h"), diff --git a/tests/device_tests/nem/test_signtx_transfers.py b/tests/device_tests/nem/test_signtx_transfers.py index 2df62b5593..15b4250cd0 100644 --- a/tests/device_tests/nem/test_signtx_transfers.py +++ b/tests/device_tests/nem/test_signtx_transfers.py @@ -33,8 +33,8 @@ pytestmark = [ # assertion data from T1 @pytest.mark.parametrize("chunkify", (True, False)) def test_nem_signtx_simple(session: Session, chunkify: bool): - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), @@ -83,8 +83,8 @@ def test_nem_signtx_simple(session: Session, chunkify: bool): @pytest.mark.setup_client(mnemonic=MNEMONIC12) def test_nem_signtx_encrypted_payload(session: Session): - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ # Confirm transfer and network fee messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py index 8841a52426..bb5594ded9 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_dryrun.py @@ -52,8 +52,8 @@ def do_recover_legacy(session: Session, mnemonic: list[str]): def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): with session.client as client: - client.watch_layout() - IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) + session.client.watch_layout() + IF = InputFlowBip39RecoveryDryRun(session.client, mnemonic, mismatch=mismatch) client.set_input_flow(IF.get()) return device.recover(session, type=messages.RecoveryType.DryRun) @@ -87,8 +87,8 @@ def test_invalid_seed_t1(session: Session): @pytest.mark.models("core") def test_invalid_seed_core(session: Session): - with session, session.client as client: - client.watch_layout() + with session.client as client: + session.client.watch_layout() IF = InputFlowBip39RecoveryDryRunInvalid(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): diff --git a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py index abca75bbee..f5958f4c5b 100644 --- a/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_recovery_bip39_t2.py @@ -29,7 +29,7 @@ pytestmark = pytest.mark.models("core") @pytest.mark.uninitialized_session def test_tt_pin_passphrase(session: Session): with session.client as client: - IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") + IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" "), pin="654") client.set_input_flow(IF.get()) device.recover( session, @@ -50,7 +50,7 @@ def test_tt_pin_passphrase(session: Session): @pytest.mark.uninitialized_session def test_tt_nopin_nopassphrase(session: Session): with session.client as client: - IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) + IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" ")) client.set_input_flow(IF.get()) device.recover( session, diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py index 3eb0c4d265..8dffffe97d 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced.py @@ -49,7 +49,9 @@ def _test_secret( session: Session, shares: list[str], secret: str, click_info: bool = False ): with session.client as client: - IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) + IF = InputFlowSlip39AdvancedRecovery( + session.client, shares, click_info=click_info + ) client.set_input_flow(IF.get()) device.recover( session, @@ -90,7 +92,7 @@ def test_extra_share_entered(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort(session: Session): with session.client as client: - IF = InputFlowSlip39AdvancedRecoveryAbort(client) + IF = InputFlowSlip39AdvancedRecoveryAbort(session.client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") @@ -102,7 +104,7 @@ def test_abort(session: Session): def test_noabort(session: Session): with session.client as client: IF = InputFlowSlip39AdvancedRecoveryNoAbort( - client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 + session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") @@ -118,7 +120,7 @@ def test_same_share(session: Session): # second share is first 4 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] - with session, session.client as client: + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( session, first_share, second_share ) @@ -134,7 +136,7 @@ def test_group_threshold_reached(session: Session): # second share is first 3 words of first second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] - with session, session.client as client: + with session.client as client: IF = InputFlowSlip39AdvancedRecoveryThresholdReached( session, first_share, second_share ) diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py index 37b4a0264d..7644951023 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_advanced_dryrun.py @@ -42,7 +42,7 @@ EXTRA_GROUP_SHARE = [ def test_2of3_dryrun(session: Session): with session.client as client: IF = InputFlowSlip39AdvancedRecoveryDryRun( - client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 + session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 ) client.set_input_flow(IF.get()) device.recover( @@ -61,7 +61,7 @@ def test_2of3_invalid_seed_dryrun(session: Session): TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39AdvancedRecoveryDryRun( - client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True + session.client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True ) client.set_input_flow(IF.get()) device.recover( diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py index 1a20899279..82bd4e57f4 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic.py @@ -74,7 +74,7 @@ def test_secret( session: Session, shares: list[str], secret: str, backup_type: messages.BackupType ): with session.client as client: - IF = InputFlowSlip39BasicRecovery(client, shares) + IF = InputFlowSlip39BasicRecovery(session.client, shares) client.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") @@ -91,7 +91,7 @@ def test_secret( def test_recover_with_pin_passphrase(session: Session): with session.client as client: IF = InputFlowSlip39BasicRecovery( - client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" + session.client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" ) client.set_input_flow(IF.get()) device.recover( @@ -110,7 +110,7 @@ def test_recover_with_pin_passphrase(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_abort(session: Session): with session.client as client: - IF = InputFlowSlip39BasicRecoveryAbort(client) + IF = InputFlowSlip39BasicRecoveryAbort(session.client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") @@ -124,7 +124,7 @@ def test_abort(session: Session): def test_abort_on_number_of_words(session: Session): # on Caesar, test_abort actually aborts on the # of words selection with session.client as client: - IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(client) + IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(session.client) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): device.recover(session, pin_protection=False, label="label") @@ -136,7 +136,7 @@ def test_abort_on_number_of_words(session: Session): def test_abort_between_shares(session: Session): with session.client as client: IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( - client, MNEMONIC_SLIP39_BASIC_20_3of6 + session.client, MNEMONIC_SLIP39_BASIC_20_3of6 ) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): @@ -149,7 +149,9 @@ def test_abort_between_shares(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_noabort(session: Session): with session.client as client: - IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) + IF = InputFlowSlip39BasicRecoveryNoAbort( + session.client, MNEMONIC_SLIP39_BASIC_20_3of6 + ) client.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") session.refresh_features() @@ -158,7 +160,7 @@ def test_noabort(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_invalid_mnemonic_first_share(session: Session): - with session, session.client as client: + with session.client as client: IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): @@ -169,7 +171,7 @@ def test_invalid_mnemonic_first_share(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_invalid_mnemonic_second_share(session: Session): - with session, session.client as client: + with session.client as client: IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( session, MNEMONIC_SLIP39_BASIC_20_3of6 ) @@ -184,7 +186,7 @@ def test_invalid_mnemonic_second_share(session: Session): @pytest.mark.parametrize("nth_word", range(3)) def test_wrong_nth_word(session: Session, nth_word: int): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with session, session.client as client: + with session.client as client: IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): @@ -194,7 +196,7 @@ def test_wrong_nth_word(session: Session, nth_word: int): @pytest.mark.setup_client(uninitialized=True) def test_same_share(session: Session): share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") - with session, session.client as client: + with session.client as client: IF = InputFlowSlip39BasicRecoverySameShare(session, share) client.set_input_flow(IF.get()) with pytest.raises(exceptions.Cancelled): @@ -204,7 +206,7 @@ def test_same_share(session: Session): @pytest.mark.setup_client(uninitialized=True) def test_1of1(session: Session): with session.client as client: - IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) + IF = InputFlowSlip39BasicRecovery(session.client, MNEMONIC_SLIP39_BASIC_20_1of1) client.set_input_flow(IF.get()) device.recover( session, diff --git a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py index b9c4ca6daa..e40d8ee41a 100644 --- a/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py +++ b/tests/device_tests/reset_recovery/test_recovery_slip39_basic_dryrun.py @@ -39,7 +39,7 @@ INVALID_SHARES_20_2of3 = [ @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) def test_2of3_dryrun(session: Session): with session.client as client: - IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) + IF = InputFlowSlip39BasicRecoveryDryRun(session.client, SHARES_20_2of3[1:3]) client.set_input_flow(IF.get()) device.recover( session, @@ -57,7 +57,7 @@ def test_2of3_invalid_seed_dryrun(session: Session): TrezorFailure, match=r"The seed does not match the one in the device" ): IF = InputFlowSlip39BasicRecoveryDryRun( - client, INVALID_SHARES_20_2of3, mismatch=True + session.client, INVALID_SHARES_20_2of3, mismatch=True ) client.set_input_flow(IF.get()) device.recover( diff --git a/tests/device_tests/reset_recovery/test_reset_backup.py b/tests/device_tests/reset_recovery/test_reset_backup.py index 9710ee6201..782b58ea96 100644 --- a/tests/device_tests/reset_recovery/test_reset_backup.py +++ b/tests/device_tests/reset_recovery/test_reset_backup.py @@ -78,7 +78,7 @@ VECTORS = [ def test_skip_backup_msg(session: Session, backup_type, backup_flow): assert session.features.initialized is False - with session: + with session.client: device.setup( session, skip_backup=True, @@ -116,7 +116,7 @@ def test_skip_backup_msg(session: Session, backup_type, backup_flow): def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): assert session.features.initialized is False - with session, session.client as client: + with session.client as client: IF = InputFlowResetSkipBackup(client) client.set_input_flow(IF.get()) device.setup( diff --git a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py index 65dc8a4e6e..b29d7c3d93 100644 --- a/tests/device_tests/reset_recovery/test_reset_bip39_t2.py +++ b/tests/device_tests/reset_recovery/test_reset_bip39_t2.py @@ -36,7 +36,7 @@ pytestmark = pytest.mark.models("core") def reset_device(session: Session, strength: int): debug = session.client.debug with session.client as client: - IF = InputFlowBip39ResetBackup(client) + IF = InputFlowBip39ResetBackup(session.client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random @@ -92,7 +92,7 @@ def test_reset_device_pin(session: Session): strength = 256 # 24 words with session.client as client: - IF = InputFlowBip39ResetPIN(client) + IF = InputFlowBip39ResetPIN(session.client) client.set_input_flow(IF.get()) # PIN, passphrase, display random @@ -130,7 +130,7 @@ def test_reset_entropy_check(session: Session): strength = 128 # 12 words with session.client as client: - IF = InputFlowBip39ResetBackup(client) + IF = InputFlowBip39ResetBackup(session.client) client.set_input_flow(IF.get()) # No PIN, no passphrase @@ -146,7 +146,7 @@ def test_reset_entropy_check(session: Session): ) # Generate the mnemonic locally. - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) @@ -177,7 +177,7 @@ def test_reset_failed_check(session: Session): strength = 256 # 24 words with session.client as client: - IF = InputFlowBip39ResetFailedCheck(client) + IF = InputFlowBip39ResetFailedCheck(session.client) client.set_input_flow(IF.get()) # PIN, passphrase, display random @@ -263,9 +263,9 @@ def test_already_initialized(session: Session): @pytest.mark.setup_client(uninitialized=True) @pytest.mark.uninitialized_session def test_entropy_check(session: Session): - with session: + with session.client as client: delizia = session.client.debug.layout_type is LayoutType.Delizia - session.set_expected_responses( + client.set_expected_responses( [ messages.ButtonRequest(name="setup_device"), (delizia, messages.ButtonRequest(name="confirm_setup_device")), @@ -300,9 +300,9 @@ def test_entropy_check(session: Session): @pytest.mark.setup_client(uninitialized=True) @pytest.mark.uninitialized_session def test_no_entropy_check(session: Session): - with session: + with session.client as client: delizia = session.client.debug.layout_type is LayoutType.Delizia - session.set_expected_responses( + client.set_expected_responses( [ messages.ButtonRequest(name="setup_device"), (delizia, messages.ButtonRequest(name="confirm_setup_device")), diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py index e1ceacbb32..3eab1d811f 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_bip39.py @@ -48,7 +48,7 @@ def test_reset_recovery(client: Client): def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: with session.client as client: - IF = InputFlowBip39ResetBackup(client) + IF = InputFlowBip39ResetBackup(session.client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random @@ -78,9 +78,9 @@ def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> s def recover(session: Session, mnemonic: str): words = mnemonic.split(" ") with session.client as client: - IF = InputFlowBip39Recovery(client, words) + IF = InputFlowBip39Recovery(session.client, words) client.set_input_flow(IF.get()) - client.watch_layout() + session.client.watch_layout() device.recover(session, pin_protection=False, label="label") # Workflow successfully ended diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py index e98ad6983c..a46e8c2806 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_advanced.py @@ -69,7 +69,7 @@ def test_reset_recovery(client: Client): def reset(session: Session, strength: int = 128) -> list[str]: with session.client as client: - IF = InputFlowSlip39AdvancedResetRecovery(client, False) + IF = InputFlowSlip39AdvancedResetRecovery(session.client, False) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random diff --git a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py index 8e4e53fe47..2a81df9258 100644 --- a/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_recovery_slip39_basic.py @@ -59,7 +59,7 @@ def test_reset_recovery(client: Client): def reset(session: Session, strength: int = 128) -> list[str]: with session.client as client: - IF = InputFlowSlip39BasicResetRecovery(client) + IF = InputFlowSlip39BasicResetRecovery(session.client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random @@ -88,7 +88,7 @@ def reset(session: Session, strength: int = 128) -> list[str]: def recover(session: Session, shares: t.Sequence[str]): with session.client as client: - IF = InputFlowSlip39BasicRecovery(client, shares) + IF = InputFlowSlip39BasicRecovery(session.client, shares) client.set_input_flow(IF.get()) device.recover(session, pin_protection=False, label="label") diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py index 2d5c9edd4a..de4f75f43f 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_advanced.py @@ -34,10 +34,10 @@ def test_reset_device_slip39_advanced(client: Client): strength = 128 member_threshold = 3 - with client: + session = client.get_seedless_session() + with session.client as client: IF = InputFlowSlip39AdvancedResetRecovery(client, False) client.set_input_flow(IF.get()) - session = client.get_seedless_session() # No PIN, no passphrase, don't display random device.setup( session, diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py index dd25fc1342..989cf7f5d1 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py @@ -35,7 +35,7 @@ def reset_device(session: Session, strength: int): member_threshold = 3 with session.client as client: - IF = InputFlowSlip39BasicResetRecovery(client) + IF = InputFlowSlip39BasicResetRecovery(session.client) client.set_input_flow(IF.get()) # No PIN, no passphrase, don't display random @@ -90,7 +90,7 @@ def test_reset_entropy_check(session: Session): strength = 128 # 20 words with session.client as client: - IF = InputFlowSlip39BasicResetRecovery(client) + IF = InputFlowSlip39BasicResetRecovery(session.client) client.set_input_flow(IF.get()) # No PIN, no passphrase. diff --git a/tests/device_tests/ripple/test_get_address.py b/tests/device_tests/ripple/test_get_address.py index 2a066926cd..e95b4683a3 100644 --- a/tests/device_tests/ripple/test_get_address.py +++ b/tests/device_tests/ripple/test_get_address.py @@ -53,7 +53,7 @@ def test_ripple_get_address_chunkify_details( session: Session, path: str, expected_address: str ): with session.client as client: - IF = InputFlowShowAddressQRCode(client) + IF = InputFlowShowAddressQRCode(session.client) client.set_input_flow(IF.get()) address = get_address( session, parse_path(path), show_display=True, chunkify=True diff --git a/tests/device_tests/solana/test_sign_tx.py b/tests/device_tests/solana/test_sign_tx.py index 708ccdd69f..f42ee5135c 100644 --- a/tests/device_tests/solana/test_sign_tx.py +++ b/tests/device_tests/solana/test_sign_tx.py @@ -48,7 +48,7 @@ def test_solana_sign_tx(session: Session, parameters, result): serialized_tx = _serialize_tx(parameters["construct"]) with session.client as client: - IF = InputFlowConfirmAllWarnings(client) + IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) actual_result = sign_tx( session, diff --git a/tests/device_tests/stellar/test_stellar.py b/tests/device_tests/stellar/test_stellar.py index 1d5c59e1f8..e54e89f445 100644 --- a/tests/device_tests/stellar/test_stellar.py +++ b/tests/device_tests/stellar/test_stellar.py @@ -123,7 +123,7 @@ def test_get_address(session: Session, parameters, result): @parametrize_using_common_fixtures("stellar/get_address.json") def test_get_address_chunkify_details(session: Session, parameters, result): with session.client as client: - IF = InputFlowShowAddressQRCode(client) + IF = InputFlowShowAddressQRCode(session.client) client.set_input_flow(IF.get()) address_n = parse_path(parameters["path"]) address = stellar.get_address( diff --git a/tests/device_tests/test_autolock.py b/tests/device_tests/test_autolock.py index 423cbc5378..e075ef057c 100644 --- a/tests/device_tests/test_autolock.py +++ b/tests/device_tests/test_autolock.py @@ -38,9 +38,9 @@ def pin_request(session: Session): def set_autolock_delay(session: Session, delay): - with session, session.client as client: + with session.client as client: client.use_pin_sequence([PIN4]) - session.set_expected_responses( + client.set_expected_responses( [ pin_request(session), messages.ButtonRequest, @@ -52,18 +52,19 @@ def set_autolock_delay(session: Session, delay): def test_apply_auto_lock_delay(session: Session): + client = session.client set_autolock_delay(session, 10 * 1000) time.sleep(0.1) # sleep less than auto-lock delay - with session: + with client: # No PIN protection is required. - session.set_expected_responses([messages.Address]) + client.set_expected_responses([messages.Address]) get_test_address(session) time.sleep(10.5) # sleep more than auto-lock delay - with session, session.client as client: + with client: client.use_pin_sequence([PIN4]) - session.set_expected_responses([pin_request(session), messages.Address]) + client.set_expected_responses([pin_request(session), messages.Address]) get_test_address(session) @@ -85,7 +86,7 @@ def test_apply_auto_lock_delay_valid(session: Session, seconds): def test_autolock_default_value(session: Session): assert session.features.auto_lock_delay_ms is None - with session, session.client as client: + with session.client as client: client.use_pin_sequence([PIN4]) device.apply_settings(session, label="pls unlock") session.refresh_features() @@ -98,9 +99,9 @@ def test_autolock_default_value(session: Session): ) def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): - with session, session.client as client: - client.use_pin_sequence([PIN4]) - session.set_expected_responses( + with session.client as client: + session.client.use_pin_sequence([PIN4]) + client.set_expected_responses( [ pin_request(session), messages.Failure(code=messages.FailureType.ProcessError), diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index a7fa64a454..5fa14ceb84 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -40,8 +40,8 @@ def test_cancel_message_via_cancel(session: Session, message): yield session.cancel() - with session, session.client as client, pytest.raises(Cancelled): - session.set_expected_responses([m.ButtonRequest(), m.Failure()]) + with session.client as client, pytest.raises(Cancelled): + client.set_expected_responses([m.ButtonRequest(), m.Failure()]) client.set_input_flow(input_flow) session.call(message) diff --git a/tests/device_tests/test_language.py b/tests/device_tests/test_language.py index 0fe6e27595..dc90681a7b 100644 --- a/tests/device_tests/test_language.py +++ b/tests/device_tests/test_language.py @@ -79,7 +79,7 @@ def _check_ping_screen_texts(session: Session, title: str, right_button: str) -> if session.model in (models.T2T1, models.T3T1): right_button = "-" - with session, session.client as client: + with session.client as client: client.watch_layout(True) client.set_input_flow(ping_input_flow(session, title, right_button)) ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) @@ -93,7 +93,7 @@ def test_error_too_long(session: Session): max_length = MAX_DATA_LENGTH[session.model] with pytest.raises( exceptions.TrezorFailure, match="Translations too long" - ), session: + ), session.client: bad_data = (max_length + 1) * b"a" device.change_language(session, language_data=bad_data) assert session.features.language == "en-US" @@ -104,7 +104,9 @@ def test_error_invalid_data_length(session: Session): assert session.features.language == "en-US" # Invalid data length # Sending more data than advertised in the header - with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), session: + with pytest.raises( + exceptions.TrezorFailure, match="Invalid data length" + ), session.client: good_data = build_and_sign_blob("cs", session) bad_data = good_data + b"abcd" device.change_language(session, language_data=bad_data) @@ -118,7 +120,7 @@ def test_error_invalid_header_magic(session: Session): # Does not match the expected magic with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), session: + ), session.client: good_data = build_and_sign_blob("cs", session) bad_data = 4 * b"a" + good_data[4:] device.change_language(session, language_data=bad_data) @@ -132,7 +134,7 @@ def test_error_invalid_data_hash(session: Session): # Changing the data after their hash has been calculated with pytest.raises( exceptions.TrezorFailure, match="Translation data verification failed" - ), session: + ), session.client: good_data = build_and_sign_blob("cs", session) bad_data = good_data[:-8] + 8 * b"a" device.change_language( @@ -149,7 +151,7 @@ def test_error_version_mismatch(session: Session): # Change the version to one not matching the current device with pytest.raises( exceptions.TrezorFailure, match="Translations version mismatch" - ), session: + ), session.client: blob = prepare_blob("cs", session.model, (3, 5, 4, 0)) device.change_language( session, @@ -165,7 +167,7 @@ def test_error_invalid_signature(session: Session): # Changing the data in the signature section with pytest.raises( exceptions.TrezorFailure, match="Invalid translations data" - ), session: + ), session.client: blob = prepare_blob("cs", session.model, session.version) blob.proof = translations.Proof( merkle_proof=[], @@ -274,7 +276,7 @@ def test_reject_update(session: Session): yield session.client.debug.press_no() - with pytest.raises(exceptions.Cancelled), session, session.client as client: + with pytest.raises(exceptions.Cancelled), session.client as client: client.set_input_flow(input_flow_reject) device.change_language(session, language_data) @@ -311,8 +313,8 @@ def _maybe_confirm_set_language( else: expected_responses = expected_responses_silent - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) device.change_language(session, language_data, show_display=show_display) assert session.features.language is not None assert session.features.language[:2] == lang @@ -320,9 +322,9 @@ def _maybe_confirm_set_language( # explicitly handle the cases when expected_responses are correct for # change_language but incorrect for selected is_displayed mode (otherwise the # user would get an unhelpful generic expected_responses mismatch) - if is_displayed and session.actual_responses == expected_responses_silent: + if is_displayed and client.actual_responses == expected_responses_silent: raise AssertionError("Change should have been visible but was silent") - if not is_displayed and session.actual_responses == expected_responses_confirm: + if not is_displayed and client.actual_responses == expected_responses_confirm: raise AssertionError("Change should have been silent but was visible") # if the expected_responses do not match either, the generic error message will # be raised by the session context manager diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 40c18d2cab..9a000ea295 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -20,6 +20,7 @@ import pytest from trezorlib import btc, device, exceptions, messages, misc, models from trezorlib.debuglink import SessionDebugWrapper as Session +from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.tools import parse_path from ..input_flows import InputFlowConfirmAllWarnings @@ -50,19 +51,19 @@ T1_HOMESCREEN = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x TR_HOMESCREEN = b"TOIG\x80\x00@\x00\x0c\x04\x00\x00\xa5RY\x96\xdc0\x08\xe4\x06\xdc\xff\x96\xdc\x80\xa8\x16\x90z\xd2y\xf9\x18{\xc0\xf1\xe5\xc9y\x0f\x95\x7f;C\xfe\xd0\xe1K\xefS\x96o\xf9\xb739\x1a\n\xc7\xde\x89\xff\x11\xd8=\xd5\xcf\xb1\x9f\xf7U\xf2\xa3spx\xb0&t\xe4\xaf3x\xcaT\xec\xe50k\xb4\xe8\nl\x16\xbf`'\xf3\xa7Z\x8d-\x98h\x1c\x03\x07\xf0\xcf\xf0\x8aD\x13\xec\x1f@y\x9e\xd8\xa3\xc6\x84F*\x1dx\x02U\x00\x10\xd3\x8cF\xbb\x97y\x18J\xa5T\x18x\x1c\x02\xc6\x90\xfd\xdc\x89\x1a\x94\xb3\xeb\x01\xdc\x9f2\x8c/\xe9/\x8c$\xc6\x9c\x1e\xf8C\x8f@\x17Q\x1d\x11F\x02g\xe4A \xebO\xad\xc6\xe3F\xa7\x8b\xf830R\x82\x0b\x8e\x16\x1dL,\x14\xce\x057tht^\xfe\x00\x9e\x86\xc2\x86\xa3b~^Bl\x18\x1f\xb9+w\x11\x14\xceO\xe9\xb6W\xd8\x85\xbeX\x17\xc2\x13,M`y\xd1~\xa3/\xcd0\xed6\xda\xf5b\x15\xb5\x18\x0f_\xf6\xe2\xdc\x8d\x8ez\xdd\xd5\r^O\x9e\xb6|\xc4e\x0f\x1f\xff0k\xd4\xb8\n\x12{\x8d\x8a>\x0b5\xa2o\xf2jZ\xe5\xee\xdc\x14\xd1\xbd\xd5\xad\x95\xbe\x8c\t\x8f\xb9\xde\xc4\xa551,#`\x94'\x1b\xe7\xd53u\x8fq\xbd4v>3\x8f\xcc\x1d\xbcV>\x90^\xb3L\xc3\xde0]\x05\xec\x83\xd0\x07\xd2(\xbb\xcf+\xd0\xc7ru\xecn\x14k-\xc0|\xd2\x0e\xe8\xe08\xa8<\xdaQ+{\xad\x01\x02#\x16\x12+\xc8\xe0P\x06\xedD7\xae\xd0\xa4\x97\x84\xe32\xca;]\xd04x:\x94`\xbe\xca\x89\xe2\xcb\xc5L\x03\xac|\xe7\xd5\x1f\xe3\x08_\xee!\x04\xd2\xef\x00\xd8\xea\x91p)\xed^#\xb1\xa78eJ\x00F*\xc7\xf1\x0c\x1a\x04\xf5l\xcc\xfc\xa4\x83,c\x1e\xb1>\xc5q\x8b\xe6Y9\xc7\x07\xfa\xcf\xf9\x15\x8a\xdd\x11\x1f\x98\x82\xbe>\xbe+u#g]aC\\\x1bC\xb1\xe8P\xce2\xd6\xb6r\x12\x1c*\xd3\x92\x9d9\xf9cB\x82\xf9S.\xc2B\xe7\x9d\xcf\xdb\xf3\xfd#\xfd\x94x9p\x8d%\x14\xa5\xb3\xe9p5\xa1;~4:\xcd\xe0&\x11\x1d\xe9\xf6\xa1\x1fw\xf54\x95eWx\xda\xd0u\x91\x86\xb8\xbc\xdf\xdc\x008f\x15\xc6\xf6\x7f\xf0T\xb8\xc1\xa3\xc5_A\xc0G\x930\xe7\xdc=\xd5\xa7\xc1\xbcI\x16\xb8s\x9c&\xaa\x06\xc1}\x8b\x19\x9d'c\xc3\xe3^\xc3m\xb6n\xb0(\x16\xf6\xdeg\xb3\x96:i\xe5\x9c\x02\x93\x9fF\x9f-\xa7\"w\xf3X\x9f\x87\x08\x84\"v,\xab!9: 10_000 - with client: - client.use_pin_sequence([PIN4, PIN4]) - device.setup( - session, - skip_backup=True, - pin_protection=True, - passphrase_protection=False, - entropy_check_count=0, - backup_type=messages.BackupType.Bip39, - ) + client.use_pin_sequence([PIN4, PIN4]) + device.setup( + session, + skip_backup=True, + pin_protection=True, + passphrase_protection=False, + entropy_check_count=0, + backup_type=messages.BackupType.Bip39, + ) time.sleep(10.5) session = client.get_session() - with session, client: + with session.client as client: # after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked - session.set_expected_responses([messages.Address]) + client.set_expected_responses([messages.Address]) get_test_address(session) diff --git a/tests/device_tests/test_pin.py b/tests/device_tests/test_pin.py index c911dfee50..4f5c3a9d3f 100644 --- a/tests/device_tests/test_pin.py +++ b/tests/device_tests/test_pin.py @@ -33,17 +33,17 @@ pytestmark = pytest.mark.setup_client(pin=PIN4) @pytest.mark.setup_client(pin=None) def test_no_protection(session: Session): - with session: - session.set_expected_responses([messages.Address]) + with session.client as client: + client.set_expected_responses([messages.Address]) get_test_address(session) def test_correct_pin(session: Session): - with session, session.client as client: + with session.client as client: client.use_pin_sequence([PIN4]) # Expected responses differ between T1 and TT is_t1 = session.model is models.T1B1 - session.set_expected_responses( + client.set_expected_responses( [ (is_t1, messages.PinMatrixRequest), ( @@ -65,10 +65,10 @@ def test_incorrect_pin_t1(session: Session): @pytest.mark.models("core") def test_incorrect_pin_t2(session: Session): - with session, session.client as client: + with session.client as client: # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt client.use_pin_sequence([BAD_PIN, PIN4]) - session.set_expected_responses( + client.set_expected_responses( [ messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), @@ -82,7 +82,7 @@ def test_incorrect_pin_t2(session: Session): def test_exponential_backoff_t1(session: Session): for attempt in range(3): start = time.time() - with session, session.client as client, pytest.raises(PinException): + with session.client as client, pytest.raises(PinException): client.use_pin_sequence([BAD_PIN]) get_test_address(session) check_pin_backoff_time(attempt, start) diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index 90632ec95a..b38cb7f34e 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -97,7 +97,7 @@ def test_passphrase_reporting(session: Session, passphrase): """On TT, passphrase_protection is a private setting, so a locked device should report passphrase_protection=None. """ - with session, session.client as client: + with session.client as client: client.use_pin_sequence([PIN4]) device.apply_settings(session, use_passphrase=passphrase) @@ -164,7 +164,7 @@ def test_change_pin_t2(client: Client): _pin_request(client), _pin_request(client), ( - session.client.layout_type is LayoutType.Caesar, + client.layout_type is LayoutType.Caesar, messages.ButtonRequest, ), _pin_request(client), @@ -238,7 +238,7 @@ def test_wipe_device(client: Client): session = client.get_session() client.set_expected_responses([messages.ButtonRequest, messages.Success]) device.wipe(session) - client = session.client.get_new_client() + client = client.get_new_client() session = client.get_seedless_session() with client: client.set_expected_responses([messages.Features]) @@ -251,8 +251,8 @@ def test_wipe_device(client: Client): def test_reset_device(session: Session): assert session.features.pin_protection is False assert session.features.passphrase_protection is False - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [messages.ButtonRequest] + [messages.EntropyRequest] + [messages.ButtonRequest] * 24 @@ -289,8 +289,8 @@ def test_recovery_device(session: Session, uninitialized_session=True): assert session.features.pin_protection is False assert session.features.passphrase_protection is False session.client.use_mnemonic(MNEMONIC12) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [messages.ButtonRequest] + [messages.WordRequest] * 24 + [messages.Success] # , messages.Features] @@ -302,7 +302,7 @@ def test_recovery_device(session: Session, uninitialized_session=True): False, False, "label", - input_callback=session.client.mnemonic_callback, + input_callback=client.mnemonic_callback, ) with pytest.raises(TrezorFailure): diff --git a/tests/device_tests/test_repeated_backup.py b/tests/device_tests/test_repeated_backup.py index 601c898fbb..77b10ad455 100644 --- a/tests/device_tests/test_repeated_backup.py +++ b/tests/device_tests/test_repeated_backup.py @@ -34,12 +34,13 @@ pytestmark = pytest.mark.models("core") @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) def test_repeated_backup(session: Session): + client = session.client assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with session, session.client as client: + with client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) device.backup(session) @@ -56,7 +57,7 @@ def test_repeated_backup(session: Session): device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with session, session.client as client: + with client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) @@ -69,7 +70,7 @@ def test_repeated_backup(session: Session): assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with session, session.client as client: + with client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) device.backup(session) @@ -85,6 +86,7 @@ def test_repeated_backup(session: Session): @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_SINGLE_EXT_20) def test_repeated_backup_upgrade_single(session: Session): + client = session.client assert ( session.features.backup_availability == messages.BackupAvailability.NotAvailable ) @@ -92,7 +94,7 @@ def test_repeated_backup_upgrade_single(session: Session): assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable # unlock repeated backup by entering the single share - with session, session.client as client: + with client: IF = InputFlowSlip39BasicRecoveryDryRun( client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True ) @@ -105,7 +107,7 @@ def test_repeated_backup_upgrade_single(session: Session): assert session.features.recovery_status == messages.RecoveryStatus.Backup # we can now perform another backup - with session, session.client as client: + with client: IF = InputFlowSlip39BasicBackup(client, False, repeated=True) client.set_input_flow(IF.get()) device.backup(session) @@ -123,12 +125,13 @@ def test_repeated_backup_upgrade_single(session: Session): @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) def test_repeated_backup_cancel(session: Session): + client = session.client assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with session, session.client as client: + with client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) device.backup(session) @@ -145,7 +148,7 @@ def test_repeated_backup_cancel(session: Session): device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with session, session.client as client: + with client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) @@ -157,7 +160,7 @@ def test_repeated_backup_cancel(session: Session): ) assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = session.client.debug.read_layout() + layout = client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a Cancel message @@ -178,12 +181,13 @@ def test_repeated_backup_cancel(session: Session): @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) def test_repeated_backup_send_disallowed_message(session: Session): + client = session.client assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.recovery_status == messages.RecoveryStatus.Nothing # initial device backup mnemonics = [] - with session, session.client as client: + with client: IF = InputFlowSlip39BasicBackup(client, False) client.set_input_flow(IF.get()) device.backup(session) @@ -200,7 +204,7 @@ def test_repeated_backup_send_disallowed_message(session: Session): device.backup(session) # unlock repeated backup by entering 3 of the 5 shares we have got - with session, session.client as client: + with client: IF = InputFlowSlip39BasicRecoveryDryRun( client, mnemonics[:3], unlock_repeated_backup=True ) @@ -212,7 +216,7 @@ def test_repeated_backup_send_disallowed_message(session: Session): ) assert session.features.recovery_status == messages.RecoveryStatus.Backup - layout = session.client.debug.read_layout() + layout = client.debug.read_layout() assert TR.recovery__unlock_repeated_backup in layout.text_content() # send a GetAddress message @@ -233,8 +237,7 @@ def test_repeated_backup_send_disallowed_message(session: Session): # we are still on the confirmation screen! assert ( - TR.recovery__unlock_repeated_backup - in session.client.debug.read_layout().text_content() + TR.recovery__unlock_repeated_backup in client.debug.read_layout().text_content() ) with pytest.raises(exceptions.Cancelled): session.call(messages.Cancel()) diff --git a/tests/device_tests/test_sdcard.py b/tests/device_tests/test_sdcard.py index 8d5c45b81f..3d3584e107 100644 --- a/tests/device_tests/test_sdcard.py +++ b/tests/device_tests/test_sdcard.py @@ -36,7 +36,8 @@ def test_sd_format(session: Session): @pytest.mark.sd_card(formatted=False) def test_sd_no_format(session: Session): - debug = session.client.debug + client = session.client + debug = client.debug def input_flow(): yield # enable SD protection? @@ -45,7 +46,7 @@ def test_sd_no_format(session: Session): yield # format SD card debug.press_no() - with session, session.client as client, pytest.raises(TrezorFailure) as e: + with client, pytest.raises(TrezorFailure) as e: client.set_input_flow(input_flow) device.sd_protect(session, Op.ENABLE) @@ -55,7 +56,8 @@ def test_sd_no_format(session: Session): @pytest.mark.sd_card @pytest.mark.setup_client(pin=PIN) def test_sd_protect_unlock(session: Session): - debug = session.client.debug + client = session.client + debug = client.debug layout = debug.read_layout def input_flow_enable_sd_protect(): @@ -76,7 +78,7 @@ def test_sd_protect_unlock(session: Session): assert TR.sd_card__enabled in layout().text_content() debug.press_yes() - with session, session.client as client: + with client: client.watch_layout() client.set_input_flow(input_flow_enable_sd_protect) device.sd_protect(session, Op.ENABLE) @@ -102,7 +104,7 @@ def test_sd_protect_unlock(session: Session): assert TR.pin__changed in layout().text_content() debug.press_yes() - with session, session.client as client: + with client: client.watch_layout() client.set_input_flow(input_flow_change_pin) device.change_pin(session) @@ -125,7 +127,7 @@ def test_sd_protect_unlock(session: Session): ) debug.press_no() # close - with session, session.client as client, pytest.raises(TrezorFailure) as e: + with client, pytest.raises(TrezorFailure) as e: client.watch_layout() client.set_input_flow(input_flow_change_pin_format) device.change_pin(session) diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index ebf387333a..81d2bfd6ec 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -71,9 +71,9 @@ def test_clear_session(client: Client): assert _get_public_node(session, ADDRESS_N).xpub == XPUB session.resume() - with session: + with client: # pin and passphrase are cached - session.set_expected_responses(cached_responses) + client.set_expected_responses(cached_responses) assert _get_public_node(session, ADDRESS_N).xpub == XPUB session.lock() @@ -87,9 +87,9 @@ def test_clear_session(client: Client): assert _get_public_node(session, ADDRESS_N).xpub == XPUB session.resume() - with session: + with client: # pin and passphrase are cached - session.set_expected_responses(cached_responses) + client.set_expected_responses(cached_responses) assert _get_public_node(session, ADDRESS_N).xpub == XPUB @@ -100,8 +100,8 @@ def test_end_session(client: Client): assert session.id is not None # get_address will succeed - with session: - session.set_expected_responses([messages.Address]) + with client: + client.set_expected_responses([messages.Address]) get_test_address(session) session.end() @@ -113,13 +113,13 @@ def test_end_session(client: Client): session = client.get_session() assert session.id is not None - with session: - session.set_expected_responses([messages.Address]) + with client: + client.set_expected_responses([messages.Address]) get_test_address(session) - with session as session: + with client: # end_session should succeed on empty session too - session.set_expected_responses([messages.Success] * 2) + client.set_expected_responses([messages.Success] * 2) session.end() session.end() @@ -162,8 +162,8 @@ def test_end_session_only_current(client: Client): @pytest.mark.setup_client(passphrase=True) def test_session_recycling(client: Client): session = client.get_session(passphrase="TREZOR") - with session: - session.set_expected_responses([messages.Address]) + with client: + client.set_expected_responses([messages.Address]) address = get_test_address(session) # create and close 100 sessions - more than the session limit @@ -172,7 +172,7 @@ def test_session_recycling(client: Client): session_x.end() # it should still be possible to resume the original session - with client, session: + with client: # passphrase should still be cached expected_responses = [messages.Address] * 3 if client.protocol_version == ProtocolVersion.V1: diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 9e896f6833..77563665b4 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -65,8 +65,8 @@ def _get_xpub( else: expected_responses = [messages.PublicKey] - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) result = session.call_raw(XPUB_REQUEST) if passphrase is not None: result = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) @@ -430,7 +430,7 @@ def test_passphrase_length(client: Client): def test_hide_passphrase_from_host(client: Client): # Without safety checks, turning it on fails session = client.get_seedless_session() - with pytest.raises(TrezorFailure, match="Safety checks are strict"), client: + with pytest.raises(TrezorFailure, match="Safety checks are strict"): device.apply_settings(session, hide_passphrase_from_host=True) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) @@ -440,7 +440,7 @@ def test_hide_passphrase_from_host(client: Client): passphrase = "abc" session = _get_session(client) - with session: + with client: def input_flow(): yield @@ -455,9 +455,9 @@ def test_hide_passphrase_from_host(client: Client): else: raise KeyError - client.watch_layout() + session.client.watch_layout() client.set_input_flow(input_flow) - session.set_expected_responses( + client.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, @@ -476,7 +476,7 @@ def test_hide_passphrase_from_host(client: Client): # Starting new session, otherwise the passphrase would be cached session = _get_session(client) - with client, session: + with client: def input_flow(): yield @@ -491,9 +491,9 @@ def test_hide_passphrase_from_host(client: Client): assert passphrase in client.debug.read_layout().text_content() client.debug.press_yes() - client.watch_layout() + session.client.watch_layout() client.set_input_flow(input_flow) - session.set_expected_responses( + client.set_expected_responses( [ messages.PassphraseRequest, messages.ButtonRequest, diff --git a/tests/device_tests/tezos/test_getaddress.py b/tests/device_tests/tezos/test_getaddress.py index 9f35118370..f9fde6aade 100644 --- a/tests/device_tests/tezos/test_getaddress.py +++ b/tests/device_tests/tezos/test_getaddress.py @@ -45,7 +45,7 @@ def test_tezos_get_address_chunkify_details( session: Session, path: str, expected_address: str ): with session.client as client: - IF = InputFlowShowAddressQRCode(client) + IF = InputFlowShowAddressQRCode(session.client) client.set_input_flow(IF.get()) address = get_address( session, parse_path(path), show_display=True, chunkify=True diff --git a/tests/device_tests/tezos/test_sign_tx.py b/tests/device_tests/tezos/test_sign_tx.py index f70a4934d9..5e390779f5 100644 --- a/tests/device_tests/tezos/test_sign_tx.py +++ b/tests/device_tests/tezos/test_sign_tx.py @@ -33,7 +33,7 @@ pytestmark = [ def test_tezos_sign_tx_proposal(session: Session): - with session: + with session.client: resp = tezos.sign_tx( session, TEZOS_PATH_10, @@ -64,7 +64,7 @@ def test_tezos_sign_tx_proposal(session: Session): def test_tezos_sign_tx_multiple_proposals(session: Session): - with session: + with session.client: resp = tezos.sign_tx( session, TEZOS_PATH_10, diff --git a/tests/device_tests/webauthn/test_msg_webauthn.py b/tests/device_tests/webauthn/test_msg_webauthn.py index 7016e2f5f8..de8644ef5c 100644 --- a/tests/device_tests/webauthn/test_msg_webauthn.py +++ b/tests/device_tests/webauthn/test_msg_webauthn.py @@ -31,8 +31,8 @@ RK_CAPACITY = 100 @pytest.mark.altcoin @pytest.mark.setup_client(mnemonic=MNEMONIC12) def test_add_remove(session: Session): - with session, session.client as client: - IF = InputFlowFidoConfirm(client) + with session.client as client: + IF = InputFlowFidoConfirm(session.client) client.set_input_flow(IF.get()) # Remove index 0 should fail. diff --git a/tests/device_tests/zcash/test_sign_tx.py b/tests/device_tests/zcash/test_sign_tx.py index 4d7df80090..ab22a19057 100644 --- a/tests/device_tests/zcash/test_sign_tx.py +++ b/tests/device_tests/zcash/test_sign_tx.py @@ -95,8 +95,8 @@ def test_spend_v4_input(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -143,8 +143,8 @@ def test_send_to_multisig(session: Session): script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -190,8 +190,8 @@ def test_spend_v5_input(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -243,8 +243,8 @@ def test_one_two(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -301,8 +301,8 @@ def test_unified_address(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_output(0), @@ -365,8 +365,8 @@ def test_external_presigned(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - with session: - session.set_expected_responses( + with session.client as client: + client.set_expected_responses( [ request_input(0), request_input(1), @@ -489,8 +489,8 @@ def test_spend_multisig(session: Session): request_finished(), ] - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) signatures1, _ = btc.sign_tx( session, "Zcash Testnet", @@ -529,8 +529,8 @@ def test_spend_multisig(session: Session): multisig=multisig, ) - with session: - session.set_expected_responses(expected_responses) + with session.client as client: + client.set_expected_responses(expected_responses) signatures2, serialized_tx = btc.sign_tx( session, "Zcash Testnet", diff --git a/tests/input_flows.py b/tests/input_flows.py index 81680b6663..8044e6b9dd 100644 --- a/tests/input_flows.py +++ b/tests/input_flows.py @@ -50,16 +50,18 @@ class InputFlowBase: # There could be one common input flow for all models if hasattr(self, "input_flow_common"): - return getattr(self, "input_flow_common") + flow = getattr(self, "input_flow_common") elif self.client.layout_type is LayoutType.Bolt: - return self.input_flow_bolt + flow = self.input_flow_bolt elif self.client.layout_type is LayoutType.Caesar: - return self.input_flow_caesar + flow = self.input_flow_caesar elif self.client.layout_type is LayoutType.Delizia: - return self.input_flow_delizia + flow = self.input_flow_delizia else: raise ValueError("Unknown model") + return flow + def input_flow_bolt(self) -> BRGeneratorType: """Special for TT""" raise NotImplementedError @@ -371,7 +373,7 @@ class InputFlowSignMessageInfo(InputFlowBase): self.debug.click(self.client.debug.screen_buttons.vertical_menu_items()[1]) # address mismatch? yes! self.debug.swipe_up() - yield + yield # ? class InputFlowShowAddressQRCode(InputFlowBase): diff --git a/tests/persistence_tests/test_wipe_code.py b/tests/persistence_tests/test_wipe_code.py index 8dee771a6a..0b2df11cfb 100644 --- a/tests/persistence_tests/test_wipe_code.py +++ b/tests/persistence_tests/test_wipe_code.py @@ -11,33 +11,37 @@ WIPE_CODE = "9876" def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client.get_seedless_session()) + session = client.get_seedless_session() + device.wipe(session) client = client.get_new_client() + session = client.get_seedless_session() debuglink.load_device( - client.get_seedless_session(), + session, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE", ) - with client: + with session.client as client: client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) device.change_wipe_code(client.get_seedless_session()) def setup_device_core(client: Client, pin: str, wipe_code: str) -> None: - device.wipe(client.get_seedless_session()) + session = client.get_seedless_session() + device.wipe(session) client = client.get_new_client() + session = client.get_seedless_session() debuglink.load_device( - client.get_seedless_session(), + session, MNEMONIC12, pin, passphrase_protection=False, label="WIPECODE", ) - with client: + with session.client as client: client.use_pin_sequence([pin, wipe_code, wipe_code]) device.change_wipe_code(client.get_seedless_session()) diff --git a/tests/translations.py b/tests/translations.py index e17bdbd9b3..b00d3652f1 100644 --- a/tests/translations.py +++ b/tests/translations.py @@ -69,7 +69,7 @@ def set_language(session: Session, lang: str, *, force: bool = True): language_data = b"" else: language_data = build_and_sign_blob(lang, session) - with session: + with session.client: if not session.features.language.startswith(lang) or force: device.change_language(session, language_data) # type: ignore _CURRENT_TRANSLATION.TR = TRANSLATIONS[lang]