1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-16 08:06:05 +00:00

refactor(tests): move set_input_flow to SessionDebugWrapper context manager

[no changelog]
This commit is contained in:
Martin Milata 2025-03-06 00:35:51 +01:00 committed by M1nd3r
parent 327fe1f98b
commit 46ac326354
72 changed files with 632 additions and 668 deletions

View File

@ -789,10 +789,10 @@ class DebugUI:
def __init__(self, debuglink: DebugLink) -> None: def __init__(self, debuglink: DebugLink) -> None:
self.debuglink = debuglink self.debuglink = debuglink
self.pins: t.Iterator[str] | None = None
self.clear() self.clear()
def clear(self) -> None: def clear(self) -> None:
self.pins: t.Iterator[str] | None = None
self.passphrase = None self.passphrase = None
self.input_flow: t.Union[ self.input_flow: t.Union[
t.Generator[None, messages.ButtonRequest, None], object, None t.Generator[None, messages.ButtonRequest, None], object, None
@ -964,12 +964,10 @@ class SessionDebugWrapper(Session):
return self.client.protocol_version return self.client.protocol_version
def _write(self, msg: t.Any) -> None: def _write(self, msg: t.Any) -> None:
print("writing message:", msg.__class__.__name__)
self._session._write(self._filter_message(msg)) self._session._write(self._filter_message(msg))
def _read(self) -> t.Any: def _read(self) -> t.Any:
resp = self._filter_message(self._session._read()) resp = self._filter_message(self._session._read())
print("reading message:", resp.__class__.__name__)
if self.actual_responses is not None: if self.actual_responses is not None:
self.actual_responses.append(resp) self.actual_responses.append(resp)
return resp return resp
@ -1067,6 +1065,7 @@ class SessionDebugWrapper(Session):
Clears all debugging state that might have been modified by a testcase. Clears all debugging state that might have been modified by a testcase.
""" """
self.client.ui.clear() # type: ignore [Cannot access attribute]
self.in_with_statement = False self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None self.actual_responses: list[protobuf.MessageType] | None = None
@ -1103,7 +1102,6 @@ class SessionDebugWrapper(Session):
# If no other exception was raised, evaluate missed responses # If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch) # (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses) self._verify_responses(expected_responses, actual_responses)
elif isinstance(input_flow, t.Generator): elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in # Propagate the exception through the input flow, so that we see in
# traceback where it is stuck. # traceback where it is stuck.
@ -1163,6 +1161,45 @@ class SessionDebugWrapper(Session):
output.append("") output.append("")
return output return output
def set_input_flow(
self,
input_flow: InputFlowType | t.Callable[[], InputFlowType],
) -> None:
"""Configure a sequence of input events for the current with-block.
The `input_flow` must be a generator function. A `yield` statement in the
input flow function waits for a ButtonRequest from the device, and returns
its code.
Example usage:
>>> def input_flow():
>>> # wait for first button prompt
>>> code = yield
>>> assert code == ButtonRequestType.Other
>>> # press No
>>> client.debug.press_no()
>>>
>>> # wait for second button prompt
>>> yield
>>> # press Yes
>>> client.debug.press_yes()
>>>
>>> with session:
>>> session.set_input_flow(input_flow)
>>> some_call(session)
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
if callable(input_flow):
input_flow = input_flow()
if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function")
self.client.ui.input_flow = input_flow # type: ignore [Cannot access attribute]
next(input_flow) # start the generator
class TrezorClientDebugLink(TrezorClient): class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses # This class implements automatic responses
@ -1204,7 +1241,6 @@ class TrezorClientDebugLink(TrezorClient):
self.transport = transport self.transport = transport
self.ui: DebugUI = DebugUI(self.debug) self.ui: DebugUI = DebugUI(self.debug)
self.reset_debug_features()
self._seedless_session = self.get_seedless_session(new_session=True) self._seedless_session = self.get_seedless_session(new_session=True)
self.sync_responses() self.sync_responses()
@ -1229,15 +1265,6 @@ class TrezorClientDebugLink(TrezorClient):
new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
return new_client return new_client
def reset_debug_features(self) -> None:
"""
Prepare the debugging client for a new testcase.
Clears all debugging state that might have been modified by a testcase.
"""
self.ui: DebugUI = DebugUI(self.debug)
self.in_with_statement = False
def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any: def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later # do this raw - send ButtonAck first, notify UI later
@ -1366,43 +1393,6 @@ class TrezorClientDebugLink(TrezorClient):
else: else:
return SessionDebugWrapper(super().resume_session(session)) return SessionDebugWrapper(super().resume_session(session))
def set_input_flow(
self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
) -> None:
"""Configure a sequence of input events for the current with-block.
The `input_flow` must be a generator function. A `yield` statement in the
input flow function waits for a ButtonRequest from the device, and returns
its code.
Example usage:
>>> def input_flow():
>>> # wait for first button prompt
>>> code = yield
>>> assert code == ButtonRequestType.Other
>>> # press No
>>> client.debug.press_no()
>>>
>>> # wait for second button prompt
>>> yield
>>> # press Yes
>>> client.debug.press_yes()
>>>
>>> with client:
>>> client.set_input_flow(input_flow)
>>> some_call(client)
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
if callable(input_flow):
input_flow = input_flow()
if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function")
self.ui.input_flow = input_flow
next(input_flow) # start the generator
def watch_layout(self, watch: bool = True) -> None: def watch_layout(self, watch: bool = True) -> None:
"""Enable or disable watching layout changes. """Enable or disable watching layout changes.
@ -1416,29 +1406,6 @@ class TrezorClientDebugLink(TrezorClient):
# - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug # - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
self.debug.watch_layout(watch) self.debug.watch_layout(watch)
def __enter__(self) -> "TrezorClientDebugLink":
# For usage in with/expected_responses
if self.in_with_statement:
raise RuntimeError("Do not nest!")
self.in_with_statement = True
return self
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# grab a copy of the inputflow generator to raise an exception through it
if isinstance(self.ui, DebugUI):
input_flow = self.ui.input_flow
else:
input_flow = None
self.reset_debug_features()
if exc_type is not None and isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in
# traceback where it is stuck.
input_flow.throw(exc_type, value, traceback)
def use_pin_sequence(self, pins: t.Iterable[str]) -> None: def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
"""Respond to PIN prompts from device with the provided PINs. """Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts. The sequence must be at least as long as the expected number of PIN prompts.
@ -1450,25 +1417,6 @@ class TrezorClientDebugLink(TrezorClient):
Only applies to T1, where device prompts the host for mnemonic words.""" Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
@staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output: list[str] = []
output.append("Expected responses:")
if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)")
for i in range(start_at, stop_at):
exp = expected[i]
prefix = " " if i != current else ">>> "
output.append(textwrap.indent(exp.to_string(), prefix))
if stop_at < len(expected):
omitted = len(expected) - stop_at
output.append(f" (...{omitted} following responses omitted)")
output.append("")
return output
def sync_responses(self) -> None: def sync_responses(self) -> None:
"""Synchronize Trezor device receiving with caller. """Synchronize Trezor device receiving with caller.

View File

@ -56,7 +56,7 @@ def pin_input_flow(client: Client, old_pin: str, new_pin: str):
if __name__ == "__main__": if __name__ == "__main__":
wirelink = get_device() wirelink = get_device()
client = Client(wirelink) client = Client(wirelink)
client.open() session = client.get_seedless_session()
i = 0 i = 0
@ -76,10 +76,12 @@ if __name__ == "__main__":
# change PIN # change PIN
new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10))) new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10)))
client.set_input_flow(pin_input_flow(client, last_pin, new_pin)) session.set_input_flow(pin_input_flow(client, last_pin, new_pin))
device.change_pin(client) device.change_pin(client)
client.set_input_flow(None) session.set_input_flow(None)
last_pin = new_pin last_pin = new_pin
print(f"iteration {i}") print(f"iteration {i}")
i = i + 1 i = i + 1
wirelink.close()

View File

@ -198,7 +198,7 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
session.set_filter(messages.TxAck, None) session.set_filter(messages.TxAck, None)
return msg return msg
with session, device_handler.client: with session:
session.set_filter(messages.TxAck, sleepy_filter) session.set_filter(messages.TxAck, sleepy_filter)
# confirm transaction # confirm transaction
if debug.layout_type is LayoutType.Bolt: if debug.layout_type is LayoutType.Bolt:

View File

@ -283,7 +283,7 @@ def _client_unlocked(
test_ui = request.config.getoption("ui") test_ui = request.config.getoption("ui")
_raw_client.reset_debug_features() # _raw_client.reset_debug_features()
if isinstance(_raw_client.protocol, ProtocolV1Channel): if isinstance(_raw_client.protocol, ProtocolV1Channel):
try: try:
_raw_client.sync_responses() _raw_client.sync_responses()

View File

@ -22,6 +22,10 @@ udp.SOCKET_TIMEOUT = 0.1
class NullUI: class NullUI:
@staticmethod
def clear(*args, **kwargs):
pass
@staticmethod @staticmethod
def button_request(code): def button_request(code):
pass pass

View File

@ -51,9 +51,9 @@ def test_binance_get_address_chunkify_details(
): ):
# data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50
with session.client as client: with session:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
address = get_address( address = get_address(
session, parse_path(path), show_display=True, chunkify=True session, parse_path(path), show_display=True, chunkify=True
) )

View File

@ -32,9 +32,9 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0")
mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin" mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin"
) )
def test_binance_get_public_key(session: Session): def test_binance_get_public_key(session: Session):
with session.client as client: with session:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) sig = binance.get_public_key(session, BINANCE_PATH, show_display=True)
assert ( assert (
sig.hex() sig.hex()

View File

@ -65,8 +65,8 @@ def test_sign_tx(session: Session, chunkify: bool):
assert session.features.unlocked is False assert session.features.unlocked is False
commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big")
with session.client as client: with session:
client.use_pin_sequence([PIN]) session.client.use_pin_sequence([PIN])
btc.authorize_coinjoin( btc.authorize_coinjoin(
session, session,
coordinator="www.example.com", coordinator="www.example.com",

View File

@ -168,9 +168,9 @@ def _address_n(purpose, coin, account, script_type):
def test_descriptors( def test_descriptors(
session: Session, coin, account, purpose, script_type, descriptors session: Session, coin, account, purpose, script_type, descriptors
): ):
with session.client as client: with session:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
address_n = _address_n(purpose, coin, account, script_type) address_n = _address_n(purpose, coin, account, script_type)
res = btc.get_public_node( res = btc.get_public_node(
@ -191,10 +191,10 @@ def test_descriptors(
def test_descriptors_trezorlib( def test_descriptors_trezorlib(
session: Session, coin, account, purpose, script_type, descriptors session: Session, coin, account, purpose, script_type, descriptors
): ):
with session.client as client: with session:
if client.model != models.T1B1: if session.client.model != models.T1B1:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
res = btc_cli._get_descriptor( res = btc_cli._get_descriptor(
session, coin, account, purpose, script_type, show_display=True session, coin, account, purpose, script_type, show_display=True
) )

View File

@ -270,10 +270,10 @@ def test_multisig(session: Session):
xpubs.append(node.xpub) xpubs.append(node.xpub)
for nr in range(1, 4): for nr in range(1, 4):
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
session, session,
@ -321,10 +321,10 @@ def test_multisig_missing(session: Session, show_display):
) )
for multisig in (multisig1, multisig2): for multisig in (multisig1, multisig2):
with session.client as client, pytest.raises(TrezorFailure): with pytest.raises(TrezorFailure), session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.get_address( btc.get_address(
session, session,
"Bitcoin", "Bitcoin",
@ -345,10 +345,10 @@ def test_bch_multisig(session: Session):
xpubs.append(node.xpub) xpubs.append(node.xpub)
for nr in range(1, 4): for nr in range(1, 4):
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
session, session,
@ -406,7 +406,7 @@ def test_unknown_path(session: Session):
# disable safety checks # disable safety checks
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
with session, session.client as client: with session:
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest( messages.ButtonRequest(
@ -417,8 +417,8 @@ def test_unknown_path(session: Session):
] ]
) )
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# try again with a warning # try again with a warning
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True)
@ -455,10 +455,10 @@ def test_multisig_different_paths(session: Session):
with pytest.raises( with pytest.raises(
Exception, match="Using different paths for different xpubs is not allowed" Exception, match="Using different paths for different xpubs is not allowed"
): ):
with session.client as client, session: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.get_address( btc.get_address(
session, session,
"Bitcoin", "Bitcoin",
@ -469,10 +469,10 @@ def test_multisig_different_paths(session: Session):
) )
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.get_address( btc.get_address(
session, session,
"Bitcoin", "Bitcoin",

View File

@ -74,10 +74,10 @@ def test_show_segwit(session: Session):
@pytest.mark.altcoin @pytest.mark.altcoin
def test_show_segwit_altcoin(session: Session): def test_show_segwit_altcoin(session: Session):
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
session, session,

View File

@ -63,9 +63,9 @@ def test_show_t1(
yield yield
session.client.debug.press_yes() session.client.debug.press_yes()
with session.client as client: with session:
# This is the only place where even T1 is using input flow # This is the only place where even T1 is using input flow
client.set_input_flow(input_flow_t1) session.set_input_flow(input_flow_t1)
assert ( assert (
btc.get_address( btc.get_address(
session, session,
@ -88,9 +88,9 @@ def test_show_tt(
script_type: messages.InputScriptType, script_type: messages.InputScriptType,
address: str, address: str,
): ):
with session.client as client: with session:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
session, session,
@ -109,9 +109,9 @@ def test_show_tt(
def test_show_cancel( def test_show_cancel(
session: Session, path: str, script_type: messages.InputScriptType, address: str session: Session, path: str, script_type: messages.InputScriptType, address: str
): ):
with session.client as client, pytest.raises(Cancelled): with session, pytest.raises(Cancelled):
IF = InputFlowShowAddressQRCodeCancel(client) IF = InputFlowShowAddressQRCodeCancel(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.get_address( btc.get_address(
session, session,
"Bitcoin", "Bitcoin",
@ -157,10 +157,10 @@ def test_show_multisig_3(session: Session):
for multisig in (multisig1, multisig2): for multisig in (multisig1, multisig2):
for i in [1, 2, 3]: for i in [1, 2, 3]:
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
session, session,
@ -273,11 +273,11 @@ def test_show_multisig_xpubs(
) )
for i in range(3): for i in range(3):
with session, session.client as client: with session:
IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) IF = InputFlowShowMultisigXPUBs(session.client, address, xpubs, i)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
client.debug.synchronize_at("Homescreen") session.client.debug.synchronize_at("Homescreen")
client.watch_layout() session.client.watch_layout()
btc.get_address( btc.get_address(
session, session,
"Bitcoin", "Bitcoin",
@ -314,10 +314,10 @@ def test_show_multisig_15(session: Session):
for multisig in [multisig1, multisig2]: for multisig in [multisig1, multisig2]:
for i in range(15): for i in range(15):
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
session, session,

View File

@ -119,9 +119,9 @@ def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub):
@pytest.mark.models("core") @pytest.mark.models("core")
@pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN)
def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub):
with session.client as client: with session:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True)
assert res.xpub == xpub assert res.xpub == xpub
assert bip32.serialize(res.node, xpub_magic) == xpub assert bip32.serialize(res.node, xpub_magic) == xpub
@ -158,14 +158,14 @@ def test_get_public_node_show_legacy(
client.debug.press_yes() # finish the flow client.debug.press_yes() # finish the flow
yield yield
with client: with session:
# test XPUB display flow (without showing QR code) # test XPUB display flow (without showing QR code)
res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True)
assert res.xpub == xpub assert res.xpub == xpub
assert bip32.serialize(res.node, xpub_magic) == xpub assert bip32.serialize(res.node, xpub_magic) == xpub
# test XPUB QR code display using the input flow above # test XPUB QR code display using the input flow above
client.set_input_flow(input_flow) session.set_input_flow(input_flow)
res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True)
assert res.xpub == xpub assert res.xpub == xpub
assert bip32.serialize(res.node, xpub_magic) == xpub assert bip32.serialize(res.node, xpub_magic) == xpub

View File

@ -475,10 +475,10 @@ def test_attack_change_input(session: Session):
) )
# Transaction can be signed without the attack processor # Transaction can be signed without the attack processor
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
"Testnet", "Testnet",

View File

@ -288,7 +288,7 @@ def test_external_internal(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session, session.client as client: with session:
session.set_expected_responses( session.set_expected_responses(
_responses( _responses(
session, session,
@ -299,8 +299,8 @@ def test_external_internal(session: Session):
) )
) )
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",
@ -324,7 +324,7 @@ def test_internal_external(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session, session.client as client: with session:
session.set_expected_responses( session.set_expected_responses(
_responses( _responses(
session, session,
@ -335,8 +335,8 @@ def test_internal_external(session: Session):
) )
) )
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",

View File

@ -113,10 +113,10 @@ def test_getaddress(
script_types: list[messages.InputScriptType], script_types: list[messages.InputScriptType],
): ):
for script_type in script_types: for script_type in script_types:
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
res = btc.get_address( res = btc.get_address(
session, session,
"Bitcoin", "Bitcoin",
@ -134,10 +134,10 @@ def test_signmessage(
session: Session, path: str, script_types: list[messages.InputScriptType] session: Session, path: str, script_types: list[messages.InputScriptType]
): ):
for script_type in script_types: for script_type in script_types:
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
sig = btc.sign_message( sig = btc.sign_message(
session, session,
@ -175,10 +175,10 @@ def test_signtx(
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx}
) )
@ -202,10 +202,10 @@ def test_getaddress_multisig(
] ]
multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2) multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2)
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
address = btc.get_address( address = btc.get_address(
session, session,
"Bitcoin", "Bitcoin",
@ -261,10 +261,10 @@ def test_signtx_multisig(session: Session, paths: list[str], address_index: list
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
sig, _ = btc.sign_tx( sig, _ = btc.sign_tx(
session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx}
) )

View File

@ -327,9 +327,9 @@ def test_signmessage_long(
message: str, message: str,
signature: str, signature: str,
): ):
with session.client as client: with session:
IF = InputFlowSignVerifyMessageLong(client) IF = InputFlowSignVerifyMessageLong(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
sig = btc.sign_message( sig = btc.sign_message(
session, session,
coin_name=coin_name, coin_name=coin_name,
@ -356,9 +356,9 @@ def test_signmessage_info(
message: str, message: str,
signature: str, signature: str,
): ):
with session.client as client, pytest.raises(Cancelled): with session, pytest.raises(Cancelled):
IF = InputFlowSignMessageInfo(client) IF = InputFlowSignMessageInfo(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
sig = btc.sign_message( sig = btc.sign_message(
session, session,
coin_name=coin_name, coin_name=coin_name,
@ -390,13 +390,13 @@ MESSAGE_LENGTHS = (
@pytest.mark.models("core") @pytest.mark.models("core")
@pytest.mark.parametrize("message,is_long", MESSAGE_LENGTHS) @pytest.mark.parametrize("message,is_long", MESSAGE_LENGTHS)
def test_signmessage_pagination(session: Session, message: str, is_long: bool): def test_signmessage_pagination(session: Session, message: str, is_long: bool):
with session.client as client: with session:
IF = ( IF = (
InputFlowSignVerifyMessageLong InputFlowSignVerifyMessageLong
if is_long if is_long
else InputFlowSignMessagePagination else InputFlowSignMessagePagination
)(client) )(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_message( btc.sign_message(
session, session,
coin_name="Bitcoin", coin_name="Bitcoin",
@ -438,7 +438,7 @@ def test_signmessage_pagination_trailing_newline(session: Session):
def test_signmessage_path_warning(session: Session): def test_signmessage_path_warning(session: Session):
message = "This is an example of a signed message." message = "This is an example of a signed message."
with session, session.client as client: with session:
session.set_expected_responses( session.set_expected_responses(
[ [
# expect a path warning # expect a path warning
@ -451,8 +451,8 @@ def test_signmessage_path_warning(session: Session):
] ]
) )
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_message( btc.sign_message(
session, session,
coin_name="Bitcoin", coin_name="Bitcoin",

View File

@ -664,9 +664,9 @@ def test_fee_high_hardfail(session: Session):
device.apply_settings( device.apply_settings(
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
) )
with session.client as client: with session:
IF = InputFlowSignTxHighFee(client) IF = InputFlowSignTxHighFee(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET
@ -1467,9 +1467,9 @@ def test_lock_time_blockheight(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session.client as client: with session:
IF = InputFlowLockTimeBlockHeight(client, "499999999") IF = InputFlowLockTimeBlockHeight(session.client, "499999999")
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
@ -1506,9 +1506,9 @@ def test_lock_time_datetime(session: Session, lock_time_str: str):
lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc) lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc)
lock_time_timestamp = int(lock_time_utc.timestamp()) lock_time_timestamp = int(lock_time_utc.timestamp())
with session.client as client: with session:
IF = InputFlowLockTimeDatetime(client, lock_time_str) IF = InputFlowLockTimeDatetime(session.client, lock_time_str)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
@ -1538,9 +1538,9 @@ def test_information(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session.client as client: with session:
IF = InputFlowSignTxInformation(client) IF = InputFlowSignTxInformation(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
@ -1573,9 +1573,9 @@ def test_information_mixed(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session.client as client: with session:
IF = InputFlowSignTxInformationMixed(client) IF = InputFlowSignTxInformationMixed(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
@ -1604,9 +1604,9 @@ def test_information_cancel(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session.client as client, pytest.raises(Cancelled): with session, pytest.raises(Cancelled):
IF = InputFlowSignTxInformationCancel(client) IF = InputFlowSignTxInformationCancel(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
@ -1654,9 +1654,9 @@ def test_information_replacement(session: Session):
orig_index=0, orig_index=0,
) )
with session.client as client: with session:
IF = InputFlowSignTxInformationReplacement(client) IF = InputFlowSignTxInformationReplacement(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,

View File

@ -80,10 +80,10 @@ def test_invalid_path_prompt(session: Session):
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
) )
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES)
@ -106,10 +106,10 @@ def test_invalid_path_pass_forkid(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session.client as client: with session:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES)

View File

@ -203,9 +203,9 @@ def test_payment_request_details(session: Session):
) )
] ]
with session.client as client: with session:
IF = InputFlowPaymentRequestDetails(client, outputs) IF = InputFlowPaymentRequestDetails(session.client, outputs)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, session,

View File

@ -130,11 +130,11 @@ def test_invalid_prev_hash_attack(session: Session, prev_hash):
msg.tx.inputs[0].prev_hash = prev_hash msg.tx.inputs[0].prev_hash = prev_hash
return msg return msg
with session, session.client as client, pytest.raises(TrezorFailure) as e: with session, pytest.raises(TrezorFailure) as e:
session.set_filter(messages.TxAck, attack_filter) session.set_filter(messages.TxAck, attack_filter)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES) btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES)
# check that injection was performed # check that injection was performed
@ -168,9 +168,9 @@ def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash):
tx_hash = hash_tx(serialize_tx(prev_tx)) tx_hash = hash_tx(serialize_tx(prev_tx))
inp0.prev_hash = tx_hash inp0.prev_hash = tx_hash
with session, session.client as client, pytest.raises(TrezorFailure) as e: with session, pytest.raises(TrezorFailure) as e:
if session.model is not models.T1B1: if session.model is not models.T1B1:
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx})
_check_error_message(prev_hash, session.model, e.value.message) _check_error_message(prev_hash, session.model, e.value.message)

View File

@ -611,11 +611,11 @@ def test_send_multisig_3_change(session: Session):
request_finished(), request_finished(),
] ]
with session, session.client as client: with session:
session.set_expected_responses(expected_responses) session.set_expected_responses(expected_responses)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
signatures, _ = btc.sign_tx( signatures, _ = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -626,11 +626,11 @@ def test_send_multisig_3_change(session: Session):
inp1.address_n[2] = H_(3) inp1.address_n[2] = H_(3)
out1.address_n[2] = H_(3) out1.address_n[2] = H_(3)
with session, session.client as client: with session:
session.set_expected_responses(expected_responses) session.set_expected_responses(expected_responses)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -703,11 +703,11 @@ def test_send_multisig_4_change(session: Session):
request_finished(), request_finished(),
] ]
with session, session.client as client: with session:
session.set_expected_responses(expected_responses) session.set_expected_responses(expected_responses)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
signatures, _ = btc.sign_tx( signatures, _ = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -718,11 +718,11 @@ def test_send_multisig_4_change(session: Session):
inp1.address_n[2] = H_(3) inp1.address_n[2] = H_(3)
out1.address_n[2] = H_(3) out1.address_n[2] = H_(3)
with session, session.client as client: with session:
session.set_expected_responses(expected_responses) session.set_expected_responses(expected_responses)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )

View File

@ -40,9 +40,9 @@ def test_message_long_legacy(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_message_long_core(session: Session): def test_message_long_core(session: Session):
with session.client as client: with session:
IF = InputFlowSignVerifyMessageLong(client, verify=True) IF = InputFlowSignVerifyMessageLong(session.client, verify=True)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
ret = btc.verify_message( ret = btc.verify_message(
session, session,
"Bitcoin", "Bitcoin",

View File

@ -95,9 +95,11 @@ def test_cardano_get_address(session: Session, chunkify: bool, parameters, resul
"cardano/get_public_key.derivations.json", "cardano/get_public_key.derivations.json",
) )
def test_cardano_get_public_key(session: Session, parameters, result): def test_cardano_get_public_key(session: Session, parameters, result):
with session, session.client as client: with session:
IF = InputFlowShowXpubQRCode(client, passphrase=bool(session.passphrase)) IF = InputFlowShowXpubQRCode(
client.set_input_flow(IF.get()) session.client, passphrase=bool(session.passphrase)
)
session.set_input_flow(IF.get())
# session.init_device(new_session=True, derive_cardano=True) # session.init_device(new_session=True, derive_cardano=True)
derivation_type = CardanoDerivationType.__members__[ derivation_type = CardanoDerivationType.__members__[

View File

@ -63,7 +63,7 @@ def test_cardano_sign_tx(session: Session, parameters, result):
response = call_sign_tx( response = call_sign_tx(
session, session,
parameters, parameters,
input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), input_flow=lambda client: InputFlowConfirmAllWarnings(session.client).get(),
) )
assert response == _transform_expected_result(result) assert response == _transform_expected_result(result)
@ -122,10 +122,10 @@ def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool =
else: else:
device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict) device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict)
with session.client as client: with session:
if input_flow is not None: if input_flow is not None:
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow(client)) session.set_input_flow(input_flow(session.client))
return cardano.sign_tx( return cardano.sign_tx(
session=session, session=session,

View File

@ -29,9 +29,9 @@ from ...input_flows import InputFlowShowXpubQRCode
@pytest.mark.models("t2t1") @pytest.mark.models("t2t1")
@pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_eos_get_public_key(session: Session): def test_eos_get_public_key(session: Session):
with session.client as client: with session:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
public_key = get_public_key( public_key = get_public_key(
session, parse_path("m/44h/194h/0h/0/0"), show_display=True session, parse_path("m/44h/194h/0h/0/0"), show_display=True
) )

View File

@ -123,9 +123,9 @@ def test_external_token(session: Session) -> None:
def test_external_chain_without_token(session: Session) -> None: def test_external_chain_without_token(session: Session) -> None:
with session, session.client as client: with session:
if not client.debug.legacy_debug: if not session.client.debug.legacy_debug:
client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get())
# when using an external chains, unknown tokens are allowed # when using an external chains, unknown tokens are allowed
network = common.encode_network(chain_id=66666, slip44=60) network = common.encode_network(chain_id=66666, slip44=60)
params = DEFAULT_ERC20_PARAMS.copy() params = DEFAULT_ERC20_PARAMS.copy()
@ -145,9 +145,9 @@ def test_external_chain_token_ok(session: Session) -> None:
def test_external_chain_token_mismatch(session: Session) -> None: def test_external_chain_token_mismatch(session: Session) -> None:
with session, session.client as client: with session:
if not client.debug.legacy_debug: if not session.client.debug.legacy_debug:
client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get())
# when providing external defs, we explicitly allow, but not use, tokens # when providing external defs, we explicitly allow, but not use, tokens
# from other chains # from other chains
network = common.encode_network(chain_id=66666, slip44=60) network = common.encode_network(chain_id=66666, slip44=60)

View File

@ -37,9 +37,9 @@ def test_getaddress(session: Session, parameters, result):
@pytest.mark.models("core", reason="No input flow for T1") @pytest.mark.models("core", reason="No input flow for T1")
@parametrize_using_common_fixtures("ethereum/getaddress.json") @parametrize_using_common_fixtures("ethereum/getaddress.json")
def test_getaddress_chunkify_details(session: Session, parameters, result): def test_getaddress_chunkify_details(session: Session, parameters, result):
with session.client as client: with session:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
address_n = parse_path(parameters["path"]) address_n = parse_path(parameters["path"])
assert ( assert (
ethereum.get_address(session, address_n, show_display=True, chunkify=True) ethereum.get_address(session, address_n, show_display=True, chunkify=True)

View File

@ -97,10 +97,10 @@ DATA = {
@pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI") @pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI")
def test_ethereum_sign_typed_data_show_more_button(session: Session): def test_ethereum_sign_typed_data_show_more_button(session: Session):
with session.client as client: with session:
client.watch_layout() session.client.watch_layout()
IF = InputFlowEIP712ShowMore(client) IF = InputFlowEIP712ShowMore(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
ethereum.sign_typed_data( ethereum.sign_typed_data(
session, session,
parse_path("m/44h/60h/0h/0/0"), parse_path("m/44h/60h/0h/0/0"),
@ -111,10 +111,10 @@ def test_ethereum_sign_typed_data_show_more_button(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_ethereum_sign_typed_data_cancel(session: Session): def test_ethereum_sign_typed_data_cancel(session: Session):
with session.client as client, pytest.raises(exceptions.Cancelled): with session, pytest.raises(exceptions.Cancelled):
client.watch_layout() session.client.watch_layout()
IF = InputFlowEIP712Cancel(client) IF = InputFlowEIP712Cancel(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
ethereum.sign_typed_data( ethereum.sign_typed_data(
session, session,
parse_path("m/44h/60h/0h/0/0"), parse_path("m/44h/60h/0h/0/0"),

View File

@ -36,9 +36,9 @@ def test_signmessage(session: Session, parameters, result):
assert res.address == result["address"] assert res.address == result["address"]
assert res.signature.hex() == result["sig"] assert res.signature.hex() == result["sig"]
else: else:
with session.client as client: with session:
IF = InputFlowSignVerifyMessageLong(client) IF = InputFlowSignVerifyMessageLong(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
res = ethereum.sign_message( res = ethereum.sign_message(
session, parse_path(parameters["path"]), parameters["msg"] session, parse_path(parameters["path"]), parameters["msg"]
) )
@ -57,9 +57,9 @@ def test_verify(session: Session, parameters, result):
) )
assert res is True assert res is True
else: else:
with session.client as client: with session:
IF = InputFlowSignVerifyMessageLong(client, verify=True) IF = InputFlowSignVerifyMessageLong(session.client, verify=True)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
res = ethereum.verify_message( res = ethereum.verify_message(
session, session,
parameters["address"], parameters["address"],

View File

@ -73,10 +73,10 @@ def _do_test_signtx(
input_flow=None, input_flow=None,
chunkify: bool = False, chunkify: bool = False,
): ):
with session.client as client: with session:
if input_flow: if input_flow:
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow) session.set_input_flow(input_flow)
sig_v, sig_r, sig_s = ethereum.sign_tx( sig_v, sig_r, sig_s = ethereum.sign_tx(
session, session,
n=parse_path(parameters["path"]), n=parse_path(parameters["path"]),
@ -151,9 +151,9 @@ def test_signtx_go_back_from_summary(session: Session):
def test_signtx_eip1559( def test_signtx_eip1559(
session: Session, chunkify: bool, parameters: dict, result: dict session: Session, chunkify: bool, parameters: dict, result: dict
): ):
with session, session.client as client: with session:
if not client.debug.legacy_debug: if not session.client.debug.legacy_debug:
client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get())
sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559(
session, session,
n=parse_path(parameters["path"]), n=parse_path(parameters["path"]),
@ -456,15 +456,15 @@ def test_signtx_data_pagination(session: Session, flow):
data=bytes.fromhex(HEXDATA), data=bytes.fromhex(HEXDATA),
) )
with session, session.client as client: with session:
client.watch_layout() session.client.watch_layout()
client.set_input_flow(flow(client)) session.set_input_flow(flow(session.client))
_sign_tx_call() _sign_tx_call()
if flow is not input_flow_data_scroll_down: if flow is not input_flow_data_scroll_down:
with session, session.client as client, pytest.raises(exceptions.Cancelled): with session, pytest.raises(exceptions.Cancelled):
client.watch_layout() session.client.watch_layout()
client.set_input_flow(flow(client, cancel=True)) session.set_input_flow(flow(session.client, cancel=True))
_sign_tx_call() _sign_tx_call()

View File

@ -33,8 +33,8 @@ def test_encrypt(client: Client):
client.debug.press_yes() client.debug.press_yes()
session = client.get_session() session = client.get_session()
with client, session: with session:
client.set_input_flow(input_flow()) session.set_input_flow(input_flow())
misc.encrypt_keyvalue( misc.encrypt_keyvalue(
session, session,
[], [],

View File

@ -56,9 +56,9 @@ def test_monero_getaddress(session: Session, path: str, expected_address: bytes)
def test_monero_getaddress_chunkify_details( def test_monero_getaddress_chunkify_details(
session: Session, path: str, expected_address: bytes session: Session, path: str, expected_address: bytes
): ):
with session.client as client: with session:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
address = monero.get_address( address = monero.get_address(
session, parse_path(path), show_display=True, chunkify=True session, parse_path(path), show_display=True, chunkify=True
) )

View File

@ -51,10 +51,10 @@ def do_recover_legacy(session: Session, mnemonic: list[str]):
def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False):
with session.client as client: with session:
client.watch_layout() session.client.watch_layout()
IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) IF = InputFlowBip39RecoveryDryRun(session.client, mnemonic, mismatch=mismatch)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
return device.recover(session, type=messages.RecoveryType.DryRun) return device.recover(session, type=messages.RecoveryType.DryRun)
@ -87,10 +87,10 @@ def test_invalid_seed_t1(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_invalid_seed_core(session: Session): def test_invalid_seed_core(session: Session):
with session, session.client as client: with session:
client.watch_layout() session.client.watch_layout()
IF = InputFlowBip39RecoveryDryRunInvalid(session) IF = InputFlowBip39RecoveryDryRunInvalid(session)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
return device.recover( return device.recover(
session, session,

View File

@ -28,9 +28,9 @@ pytestmark = pytest.mark.models("core")
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
@pytest.mark.uninitialized_session @pytest.mark.uninitialized_session
def test_tt_pin_passphrase(session: Session): def test_tt_pin_passphrase(session: Session):
with session.client as client: with session:
IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" "), pin="654")
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
pin_protection=True, pin_protection=True,
@ -49,9 +49,9 @@ def test_tt_pin_passphrase(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
@pytest.mark.uninitialized_session @pytest.mark.uninitialized_session
def test_tt_nopin_nopassphrase(session: Session): def test_tt_nopin_nopassphrase(session: Session):
with session.client as client: with session:
IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" "))
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
pin_protection=False, pin_protection=False,

View File

@ -48,9 +48,11 @@ VECTORS = (
def _test_secret( def _test_secret(
session: Session, shares: list[str], secret: str, click_info: bool = False session: Session, shares: list[str], secret: str, click_info: bool = False
): ):
with session.client as client: with session:
IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) IF = InputFlowSlip39AdvancedRecovery(
client.set_input_flow(IF.get()) session.client, shares, click_info=click_info
)
session.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
pin_protection=False, pin_protection=False,
@ -89,9 +91,9 @@ def test_extra_share_entered(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_abort(session: Session): def test_abort(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39AdvancedRecoveryAbort(client) IF = InputFlowSlip39AdvancedRecoveryAbort(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
session.refresh_features() session.refresh_features()
@ -100,11 +102,11 @@ def test_abort(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_noabort(session: Session): def test_noabort(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39AdvancedRecoveryNoAbort( IF = InputFlowSlip39AdvancedRecoveryNoAbort(
client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
session.refresh_features() session.refresh_features()
assert session.features.initialized is True assert session.features.initialized is True
@ -118,11 +120,11 @@ def test_same_share(session: Session):
# second share is first 4 words of first # second share is first 4 words of first
second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4]
with session, session.client as client: with session:
IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(
session, first_share, second_share session, first_share, second_share
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
@ -134,10 +136,10 @@ def test_group_threshold_reached(session: Session):
# second share is first 3 words of first # second share is first 3 words of first
second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3]
with session, session.client as client: with session:
IF = InputFlowSlip39AdvancedRecoveryThresholdReached( IF = InputFlowSlip39AdvancedRecoveryThresholdReached(
session, first_share, second_share session, first_share, second_share
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")

View File

@ -40,11 +40,11 @@ EXTRA_GROUP_SHARE = [
@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False)
def test_2of3_dryrun(session: Session): def test_2of3_dryrun(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39AdvancedRecoveryDryRun( IF = InputFlowSlip39AdvancedRecoveryDryRun(
client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
passphrase_protection=False, passphrase_protection=False,
@ -57,13 +57,13 @@ def test_2of3_dryrun(session: Session):
@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20)
def test_2of3_invalid_seed_dryrun(session: Session): def test_2of3_invalid_seed_dryrun(session: Session):
# test fails because of different seed on device # test fails because of different seed on device
with session.client as client, pytest.raises( with session, pytest.raises(
TrezorFailure, match=r"The seed does not match the one in the device" TrezorFailure, match=r"The seed does not match the one in the device"
): ):
IF = InputFlowSlip39AdvancedRecoveryDryRun( IF = InputFlowSlip39AdvancedRecoveryDryRun(
client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True session.client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
passphrase_protection=False, passphrase_protection=False,

View File

@ -73,9 +73,9 @@ VECTORS = (
def test_secret( def test_secret(
session: Session, shares: list[str], secret: str, backup_type: messages.BackupType session: Session, shares: list[str], secret: str, backup_type: messages.BackupType
): ):
with session.client as client: with session:
IF = InputFlowSlip39BasicRecovery(client, shares) IF = InputFlowSlip39BasicRecovery(session.client, shares)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
# Workflow successfully ended # Workflow successfully ended
@ -89,11 +89,11 @@ def test_secret(
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_recover_with_pin_passphrase(session: Session): def test_recover_with_pin_passphrase(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39BasicRecovery( IF = InputFlowSlip39BasicRecovery(
client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" session.client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654"
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
pin_protection=True, pin_protection=True,
@ -109,9 +109,9 @@ def test_recover_with_pin_passphrase(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_abort(session: Session): def test_abort(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryAbort(client) IF = InputFlowSlip39BasicRecoveryAbort(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
session.refresh_features() session.refresh_features()
@ -123,9 +123,9 @@ def test_abort(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_abort_on_number_of_words(session: Session): def test_abort_on_number_of_words(session: Session):
# on Caesar, test_abort actually aborts on the # of words selection # on Caesar, test_abort actually aborts on the # of words selection
with session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(client) IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
assert session.features.initialized is False assert session.features.initialized is False
@ -134,11 +134,11 @@ def test_abort_on_number_of_words(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_abort_between_shares(session: Session): def test_abort_between_shares(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( IF = InputFlowSlip39BasicRecoveryAbortBetweenShares(
client, MNEMONIC_SLIP39_BASIC_20_3of6 session.client, MNEMONIC_SLIP39_BASIC_20_3of6
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
session.refresh_features() session.refresh_features()
@ -148,9 +148,11 @@ def test_abort_between_shares(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_noabort(session: Session): def test_noabort(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) IF = InputFlowSlip39BasicRecoveryNoAbort(
client.set_input_flow(IF.get()) session.client, MNEMONIC_SLIP39_BASIC_20_3of6
)
session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
session.refresh_features() session.refresh_features()
assert session.features.initialized is True assert session.features.initialized is True
@ -158,9 +160,9 @@ def test_noabort(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_invalid_mnemonic_first_share(session: Session): def test_invalid_mnemonic_first_share(session: Session):
with session, session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session) IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
session.refresh_features() session.refresh_features()
@ -169,11 +171,11 @@ def test_invalid_mnemonic_first_share(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_invalid_mnemonic_second_share(session: Session): def test_invalid_mnemonic_second_share(session: Session):
with session, session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( IF = InputFlowSlip39BasicRecoveryInvalidSecondShare(
session, MNEMONIC_SLIP39_BASIC_20_3of6 session, MNEMONIC_SLIP39_BASIC_20_3of6
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
session.refresh_features() session.refresh_features()
@ -184,9 +186,9 @@ def test_invalid_mnemonic_second_share(session: Session):
@pytest.mark.parametrize("nth_word", range(3)) @pytest.mark.parametrize("nth_word", range(3))
def test_wrong_nth_word(session: Session, nth_word: int): def test_wrong_nth_word(session: Session, nth_word: int):
share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
with session, session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word) IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
@ -194,18 +196,18 @@ def test_wrong_nth_word(session: Session, nth_word: int):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_same_share(session: Session): def test_same_share(session: Session):
share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
with session, session.client as client: with session:
IF = InputFlowSlip39BasicRecoverySameShare(session, share) IF = InputFlowSlip39BasicRecoverySameShare(session, share)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_1of1(session: Session): def test_1of1(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) IF = InputFlowSlip39BasicRecovery(session.client, MNEMONIC_SLIP39_BASIC_20_1of1)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
pin_protection=False, pin_protection=False,

View File

@ -38,9 +38,9 @@ INVALID_SHARES_20_2of3 = [
@pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2])
def test_2of3_dryrun(session: Session): def test_2of3_dryrun(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) IF = InputFlowSlip39BasicRecoveryDryRun(session.client, SHARES_20_2of3[1:3])
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
passphrase_protection=False, passphrase_protection=False,
@ -53,13 +53,13 @@ def test_2of3_dryrun(session: Session):
@pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2])
def test_2of3_invalid_seed_dryrun(session: Session): def test_2of3_invalid_seed_dryrun(session: Session):
# test fails because of different seed on device # test fails because of different seed on device
with session.client as client, pytest.raises( with session, pytest.raises(
TrezorFailure, match=r"The seed does not match the one in the device" TrezorFailure, match=r"The seed does not match the one in the device"
): ):
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, INVALID_SHARES_20_2of3, mismatch=True session.client, INVALID_SHARES_20_2of3, mismatch=True
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
passphrase_protection=False, passphrase_protection=False,

View File

@ -32,9 +32,9 @@ from ...input_flows import (
def backup_flow_bip39(session: Session) -> bytes: def backup_flow_bip39(session: Session) -> bytes:
with session.client as client: with session:
IF = InputFlowBip39Backup(client) IF = InputFlowBip39Backup(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
assert IF.mnemonic is not None assert IF.mnemonic is not None
@ -42,9 +42,9 @@ def backup_flow_bip39(session: Session) -> bytes:
def backup_flow_slip39_basic(session: Session): def backup_flow_slip39_basic(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39BasicBackup(client, False) IF = InputFlowSlip39BasicBackup(session.client, False)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
groups = shamir.decode_mnemonics(IF.mnemonics[:3]) groups = shamir.decode_mnemonics(IF.mnemonics[:3])
@ -53,9 +53,9 @@ def backup_flow_slip39_basic(session: Session):
def backup_flow_slip39_advanced(session: Session): def backup_flow_slip39_advanced(session: Session):
with session.client as client: with session:
IF = InputFlowSlip39AdvancedBackup(client, False) IF = InputFlowSlip39AdvancedBackup(session.client, False)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13] mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13]
@ -116,9 +116,9 @@ def test_skip_backup_msg(session: Session, backup_type, backup_flow):
def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow):
assert session.features.initialized is False assert session.features.initialized is False
with session, session.client as client: with session:
IF = InputFlowResetSkipBackup(client) IF = InputFlowResetSkipBackup(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.setup( device.setup(
session, session,
pin_protection=False, pin_protection=False,

View File

@ -36,9 +36,9 @@ pytestmark = pytest.mark.models("core")
def reset_device(session: Session, strength: int): def reset_device(session: Session, strength: int):
debug = session.client.debug debug = session.client.debug
with session.client as client: with session:
IF = InputFlowBip39ResetBackup(client) IF = InputFlowBip39ResetBackup(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.setup( device.setup(
@ -92,9 +92,9 @@ def test_reset_device_pin(session: Session):
debug = session.client.debug debug = session.client.debug
strength = 256 # 24 words strength = 256 # 24 words
with session.client as client: with session:
IF = InputFlowBip39ResetPIN(client) IF = InputFlowBip39ResetPIN(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# PIN, passphrase, display random # PIN, passphrase, display random
device.setup( device.setup(
@ -130,9 +130,9 @@ def test_reset_device_pin(session: Session):
def test_reset_entropy_check(session: Session): def test_reset_entropy_check(session: Session):
strength = 128 # 12 words strength = 128 # 12 words
with session.client as client: with session:
IF = InputFlowBip39ResetBackup(client) IF = InputFlowBip39ResetBackup(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# No PIN, no passphrase # No PIN, no passphrase
path_xpubs = device.setup( path_xpubs = device.setup(
@ -147,7 +147,7 @@ def test_reset_entropy_check(session: Session):
) )
# Generate the mnemonic locally. # Generate the mnemonic locally.
internal_entropy = client.debug.state().reset_entropy internal_entropy = session.client.debug.state().reset_entropy
assert internal_entropy is not None assert internal_entropy is not None
entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy)
@ -156,7 +156,7 @@ def test_reset_entropy_check(session: Session):
assert IF.mnemonic == expected_mnemonic assert IF.mnemonic == expected_mnemonic
# Check that the device is properly initialized. # Check that the device is properly initialized.
if client.protocol_version is ProtocolVersion.PROTOCOL_V1: if session.client.protocol_version is ProtocolVersion.PROTOCOL_V1:
features = session.call_raw(messages.Initialize()) features = session.call_raw(messages.Initialize())
else: else:
session.refresh_features() session.refresh_features()
@ -181,9 +181,9 @@ def test_reset_failed_check(session: Session):
debug = session.client.debug debug = session.client.debug
strength = 256 # 24 words strength = 256 # 24 words
with session.client as client: with session:
IF = InputFlowBip39ResetFailedCheck(client) IF = InputFlowBip39ResetFailedCheck(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# PIN, passphrase, display random # PIN, passphrase, display random
device.setup( device.setup(

View File

@ -47,9 +47,9 @@ def test_reset_recovery(client: Client):
def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str:
with session.client as client: with session:
IF = InputFlowBip39ResetBackup(client) IF = InputFlowBip39ResetBackup(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.setup( device.setup(
@ -77,10 +77,10 @@ def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> s
def recover(session: Session, mnemonic: str): def recover(session: Session, mnemonic: str):
words = mnemonic.split(" ") words = mnemonic.split(" ")
with session.client as client: with session:
IF = InputFlowBip39Recovery(client, words) IF = InputFlowBip39Recovery(session.client, words)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
client.watch_layout() session.client.watch_layout()
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
# Workflow successfully ended # Workflow successfully ended

View File

@ -68,9 +68,9 @@ def test_reset_recovery(client: Client):
def reset(session: Session, strength: int = 128) -> list[str]: def reset(session: Session, strength: int = 128) -> list[str]:
with session.client as client: with session:
IF = InputFlowSlip39AdvancedResetRecovery(client, False) IF = InputFlowSlip39AdvancedResetRecovery(session.client, False)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.setup( device.setup(
@ -97,9 +97,9 @@ def reset(session: Session, strength: int = 128) -> list[str]:
def recover(session: Session, shares: list[str]): def recover(session: Session, shares: list[str]):
with session.client as client: with session:
IF = InputFlowSlip39AdvancedRecovery(client, shares, False) IF = InputFlowSlip39AdvancedRecovery(session.client, shares, False)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
# Workflow successfully ended # Workflow successfully ended

View File

@ -58,9 +58,9 @@ def test_reset_recovery(client: Client):
def reset(session: Session, strength: int = 128) -> list[str]: def reset(session: Session, strength: int = 128) -> list[str]:
with session.client as client: with session:
IF = InputFlowSlip39BasicResetRecovery(client) IF = InputFlowSlip39BasicResetRecovery(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.setup( device.setup(
@ -87,9 +87,9 @@ def reset(session: Session, strength: int = 128) -> list[str]:
def recover(session: Session, shares: t.Sequence[str]): def recover(session: Session, shares: t.Sequence[str]):
with session.client as client: with session:
IF = InputFlowSlip39BasicRecovery(client, shares) IF = InputFlowSlip39BasicRecovery(session.client, shares)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
# Workflow successfully ended # Workflow successfully ended

View File

@ -34,10 +34,10 @@ def test_reset_device_slip39_advanced(client: Client):
strength = 128 strength = 128
member_threshold = 3 member_threshold = 3
with client: session = client.get_seedless_session()
with session:
IF = InputFlowSlip39AdvancedResetRecovery(client, False) IF = InputFlowSlip39AdvancedResetRecovery(client, False)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
session = client.get_seedless_session()
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.setup( device.setup(
session, session,

View File

@ -34,9 +34,9 @@ pytestmark = pytest.mark.models("core")
def reset_device(session: Session, strength: int): def reset_device(session: Session, strength: int):
member_threshold = 3 member_threshold = 3
with session.client as client: with session:
IF = InputFlowSlip39BasicResetRecovery(client) IF = InputFlowSlip39BasicResetRecovery(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.setup( device.setup(
@ -89,9 +89,9 @@ def test_reset_entropy_check(session: Session):
strength = 128 # 20 words strength = 128 # 20 words
with session.client as client: with session:
IF = InputFlowSlip39BasicResetRecovery(client) IF = InputFlowSlip39BasicResetRecovery(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# No PIN, no passphrase. # No PIN, no passphrase.
path_xpubs = device.setup( path_xpubs = device.setup(

View File

@ -52,9 +52,9 @@ def test_ripple_get_address(session: Session, path: str, expected_address: str):
def test_ripple_get_address_chunkify_details( def test_ripple_get_address_chunkify_details(
session: Session, path: str, expected_address: str session: Session, path: str, expected_address: str
): ):
with session.client as client: with session:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
address = get_address( address = get_address(
session, parse_path(path), show_display=True, chunkify=True session, parse_path(path), show_display=True, chunkify=True
) )

View File

@ -47,9 +47,9 @@ pytestmark = [
def test_solana_sign_tx(session: Session, parameters, result): def test_solana_sign_tx(session: Session, parameters, result):
serialized_tx = _serialize_tx(parameters["construct"]) serialized_tx = _serialize_tx(parameters["construct"])
with session.client as client: with session:
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
actual_result = sign_tx( actual_result = sign_tx(
session, session,
address_n=parse_path(parameters["address"]), address_n=parse_path(parameters["address"]),

View File

@ -122,9 +122,9 @@ def test_get_address(session: Session, parameters, result):
@pytest.mark.models("core") @pytest.mark.models("core")
@parametrize_using_common_fixtures("stellar/get_address.json") @parametrize_using_common_fixtures("stellar/get_address.json")
def test_get_address_chunkify_details(session: Session, parameters, result): def test_get_address_chunkify_details(session: Session, parameters, result):
with session.client as client: with session:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
address_n = parse_path(parameters["path"]) address_n = parse_path(parameters["path"])
address = stellar.get_address( address = stellar.get_address(
session, address_n, show_display=True, chunkify=True session, address_n, show_display=True, chunkify=True

View File

@ -38,8 +38,8 @@ def pin_request(session: Session):
def set_autolock_delay(session: Session, delay): def set_autolock_delay(session: Session, delay):
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
pin_request(session), pin_request(session),
@ -61,8 +61,8 @@ def test_apply_auto_lock_delay(session: Session):
get_test_address(session) get_test_address(session)
time.sleep(10.5) # sleep more than auto-lock delay time.sleep(10.5) # sleep more than auto-lock delay
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses([pin_request(session), messages.Address]) session.set_expected_responses([pin_request(session), messages.Address])
get_test_address(session) get_test_address(session)
@ -85,8 +85,8 @@ def test_apply_auto_lock_delay_valid(session: Session, seconds):
def test_autolock_default_value(session: Session): def test_autolock_default_value(session: Session):
assert session.features.auto_lock_delay_ms is None assert session.features.auto_lock_delay_ms is None
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
device.apply_settings(session, label="pls unlock") device.apply_settings(session, label="pls unlock")
session.refresh_features() session.refresh_features()
assert session.features.auto_lock_delay_ms == 60 * 10 * 1000 assert session.features.auto_lock_delay_ms == 60 * 10 * 1000
@ -98,8 +98,8 @@ def test_autolock_default_value(session: Session):
) )
def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): def test_apply_auto_lock_delay_out_of_range(session: Session, seconds):
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
pin_request(session), pin_request(session),

View File

@ -48,8 +48,8 @@ def test_busy_state(session: Session):
_assert_busy(session, True) _assert_busy(session, True)
assert session.features.unlocked is False assert session.features.unlocked is False
with session.client as client: with session:
client.use_pin_sequence([PIN]) session.client.use_pin_sequence([PIN])
btc.get_address( btc.get_address(
session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True
) )

View File

@ -40,9 +40,9 @@ def test_cancel_message_via_cancel(session: Session, message):
yield yield
session.cancel() session.cancel()
with session, session.client as client, pytest.raises(Cancelled): with session, pytest.raises(Cancelled):
session.set_expected_responses([m.ButtonRequest(), m.Failure()]) session.set_expected_responses([m.ButtonRequest(), m.Failure()])
client.set_input_flow(input_flow) session.set_input_flow(input_flow)
session.call(message) session.call(message)

View File

@ -47,12 +47,12 @@ def test_pin(session: Session):
) )
assert isinstance(resp, messages.PinMatrixRequest) assert isinstance(resp, messages.PinMatrixRequest)
with session.client as client: with session:
state = client.debug.state() state = session.client.debug.state()
assert state.pin == "1234" assert state.pin == "1234"
assert state.matrix != "" assert state.matrix != ""
pin_encoded = client.debug.encode_pin("1234") pin_encoded = session.client.debug.encode_pin("1234")
resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded)) resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded))
assert isinstance(resp, messages.PassphraseRequest) assert isinstance(resp, messages.PassphraseRequest)

View File

@ -79,9 +79,9 @@ def _check_ping_screen_texts(session: Session, title: str, right_button: str) ->
if session.model in (models.T2T1, models.T3T1): if session.model in (models.T2T1, models.T3T1):
right_button = "-" right_button = "-"
with session, session.client as client: with session:
client.watch_layout(True) session.client.watch_layout(True)
client.set_input_flow(ping_input_flow(session, title, right_button)) session.set_input_flow(ping_input_flow(session, title, right_button))
ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) ping = session.call(messages.Ping(message="ahoj!", button_protection=True))
assert ping == messages.Success(message="ahoj!") assert ping == messages.Success(message="ahoj!")
@ -274,8 +274,8 @@ def test_reject_update(session: Session):
yield yield
session.client.debug.press_no() session.client.debug.press_no()
with pytest.raises(exceptions.Cancelled), session, session.client as client: with pytest.raises(exceptions.Cancelled), session:
client.set_input_flow(input_flow_reject) session.set_input_flow(input_flow_reject)
device.change_language(session, language_data) device.change_language(session, language_data)
assert session.features.language == "en-US" assert session.features.language == "en-US"

View File

@ -345,12 +345,12 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways
with session, session.client as client: with session:
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address] [messages.ButtonRequest, messages.ButtonRequest, messages.Address]
) )
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
get_bad_address() get_bad_address()
with session: with session:
@ -371,13 +371,13 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily
with session, session.client as client: with session:
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address] [messages.ButtonRequest, messages.ButtonRequest, messages.Address]
) )
if session.model is not models.T1B1: if session.model is not models.T1B1:
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
get_bad_address() get_bad_address()
@ -412,8 +412,8 @@ def test_experimental_features(session: Session):
# relock and try again # relock and try again
session.lock() session.lock()
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses([messages.ButtonRequest, messages.Nonce]) session.set_expected_responses([messages.ButtonRequest, messages.Nonce])
experimental_call() experimental_call()

View File

@ -44,9 +44,9 @@ from ..input_flows import (
def test_backup_bip39(session: Session): def test_backup_bip39(session: Session):
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client: with session:
IF = InputFlowBip39Backup(client) IF = InputFlowBip39Backup(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
assert IF.mnemonic == MNEMONIC12 assert IF.mnemonic == MNEMONIC12
@ -71,9 +71,9 @@ def test_backup_slip39_basic(session: Session, click_info: bool):
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client: with session:
IF = InputFlowSlip39BasicBackup(client, click_info) IF = InputFlowSlip39BasicBackup(session.client, click_info)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
session.refresh_features() session.refresh_features()
@ -95,11 +95,12 @@ def test_backup_slip39_basic(session: Session, click_info: bool):
def test_backup_slip39_single(session: Session): def test_backup_slip39_single(session: Session):
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client: with session:
IF = InputFlowBip39Backup( IF = InputFlowBip39Backup(
client, confirm_success=(client.layout_type is not LayoutType.Delizia) session.client,
confirm_success=(session.client.layout_type is not LayoutType.Delizia),
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
assert session.features.initialized is True assert session.features.initialized is True
@ -126,9 +127,9 @@ def test_backup_slip39_advanced(session: Session, click_info: bool):
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client: with session:
IF = InputFlowSlip39AdvancedBackup(client, click_info) IF = InputFlowSlip39AdvancedBackup(session.client, click_info)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
session.refresh_features() session.refresh_features()
@ -157,9 +158,9 @@ def test_backup_slip39_advanced(session: Session, click_info: bool):
def test_backup_slip39_custom(session: Session, share_threshold, share_count): def test_backup_slip39_custom(session: Session, share_threshold, share_count):
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client: with session:
IF = InputFlowSlip39CustomBackup(client, share_count) IF = InputFlowSlip39CustomBackup(session.client, share_count)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup( device.backup(
session, group_threshold=1, groups=[(share_threshold, share_count)] session, group_threshold=1, groups=[(share_threshold, share_count)]
) )

View File

@ -34,7 +34,7 @@ pytestmark = pytest.mark.models("legacy")
def _set_wipe_code(session: Session, pin, wipe_code): def _set_wipe_code(session: Session, pin, wipe_code):
# Set/change wipe code. # Set/change wipe code.
with session.client as client, session: with session:
if session.features.pin_protection: if session.features.pin_protection:
pins = [pin, wipe_code, wipe_code] pins = [pin, wipe_code, wipe_code]
pin_matrices = [ pin_matrices = [
@ -49,7 +49,7 @@ def _set_wipe_code(session: Session, pin, wipe_code):
messages.PinMatrixRequest(type=PinType.WipeCodeSecond), messages.PinMatrixRequest(type=PinType.WipeCodeSecond),
] ]
client.use_pin_sequence(pins) session.client.use_pin_sequence(pins)
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest()] + pin_matrices + [messages.Success] [messages.ButtonRequest()] + pin_matrices + [messages.Success]
) )
@ -58,8 +58,8 @@ def _set_wipe_code(session: Session, pin, wipe_code):
def _change_pin(session: Session, old_pin, new_pin): def _change_pin(session: Session, old_pin, new_pin):
assert session.features.pin_protection is True assert session.features.pin_protection is True
with session.client as client: with session:
client.use_pin_sequence([old_pin, new_pin, new_pin]) session.client.use_pin_sequence([old_pin, new_pin, new_pin])
try: try:
return device.change_pin(session) return device.change_pin(session)
except exceptions.TrezorFailure as f: except exceptions.TrezorFailure as f:
@ -96,8 +96,8 @@ def test_set_remove_wipe_code(session: Session):
_check_wipe_code(session, PIN4, WIPE_CODE6) _check_wipe_code(session, PIN4, WIPE_CODE6)
# Test remove wipe code. # Test remove wipe code.
with session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
device.change_wipe_code(session, remove=True) device.change_wipe_code(session, remove=True)
# Check that there's no wipe code protection now. # Check that there's no wipe code protection now.
@ -111,8 +111,8 @@ def test_set_wipe_code_mismatch(session: Session):
assert session.features.wipe_code_protection is False assert session.features.wipe_code_protection is False
# Let's set a new wipe code. # Let's set a new wipe code.
with session.client as client, session: with session:
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6]) session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest(), messages.ButtonRequest(),
@ -125,8 +125,8 @@ def test_set_wipe_code_mismatch(session: Session):
device.change_wipe_code(session) device.change_wipe_code(session)
# Check that there is no wipe code protection. # Check that there is no wipe code protection.
client.refresh_features() session.client.refresh_features()
assert client.features.wipe_code_protection is False assert session.client.features.wipe_code_protection is False
@pytest.mark.setup_client(pin=PIN4) @pytest.mark.setup_client(pin=PIN4)
@ -135,8 +135,8 @@ def test_set_wipe_code_to_pin(session: Session):
assert session.features.wipe_code_protection is None assert session.features.wipe_code_protection is None
# Let's try setting the wipe code to the curent PIN value. # Let's try setting the wipe code to the curent PIN value.
with session.client as client, session: with session:
client.use_pin_sequence([PIN4, PIN4]) session.client.use_pin_sequence([PIN4, PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest(), messages.ButtonRequest(),
@ -149,8 +149,8 @@ def test_set_wipe_code_to_pin(session: Session):
device.change_wipe_code(session) device.change_wipe_code(session)
# Check that there is no wipe code protection. # Check that there is no wipe code protection.
client.refresh_features() session.client.refresh_features()
assert client.features.wipe_code_protection is False assert session.client.features.wipe_code_protection is False
def test_set_pin_to_wipe_code(session: Session): def test_set_pin_to_wipe_code(session: Session):
@ -159,8 +159,8 @@ def test_set_pin_to_wipe_code(session: Session):
_set_wipe_code(session, None, WIPE_CODE4) _set_wipe_code(session, None, WIPE_CODE4)
# Try to set the PIN to the current wipe code value. # Try to set the PIN to the current wipe code value.
with session.client as client, session: with session:
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest(), messages.ButtonRequest(),

View File

@ -37,8 +37,8 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str):
assert session.features.wipe_code_protection is True assert session.features.wipe_code_protection is True
# Try to change the PIN to the current wipe code value. The operation should fail. # Try to change the PIN to the current wipe code value. The operation should fail.
with session, session.client as client, pytest.raises(TrezorFailure): with session, pytest.raises(TrezorFailure):
client.use_pin_sequence([pin, wipe_code, wipe_code]) session.client.use_pin_sequence([pin, wipe_code, wipe_code])
if session.client.layout_type is LayoutType.Caesar: if session.client.layout_type is LayoutType.Caesar:
br_count = 6 br_count = 6
else: else:
@ -51,8 +51,8 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str):
def _ensure_unlocked(session: Session, pin: str): def _ensure_unlocked(session: Session, pin: str):
with session, session.client as client: with session:
client.use_pin_sequence([pin]) session.client.use_pin_sequence([pin])
btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH) btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
session.refresh_features() session.refresh_features()
@ -71,11 +71,11 @@ def test_set_remove_wipe_code(session: Session):
else: else:
br_count = 5 br_count = 5
with session, session.client as client: with session:
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success] [messages.ButtonRequest()] * br_count + [messages.Success]
) )
client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX]) session.client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX])
device.change_wipe_code(session) device.change_wipe_code(session)
# session.init_device() # session.init_device()
@ -83,11 +83,11 @@ def test_set_remove_wipe_code(session: Session):
_check_wipe_code(session, PIN4, WIPE_CODE_MAX) _check_wipe_code(session, PIN4, WIPE_CODE_MAX)
# Test change wipe code. # Test change wipe code.
with session, session.client as client: with session:
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success] [messages.ButtonRequest()] * br_count + [messages.Success]
) )
client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6]) session.client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6])
device.change_wipe_code(session) device.change_wipe_code(session)
# session.init_device() # session.init_device()
@ -95,11 +95,11 @@ def test_set_remove_wipe_code(session: Session):
_check_wipe_code(session, PIN4, WIPE_CODE6) _check_wipe_code(session, PIN4, WIPE_CODE6)
# Test remove wipe code. # Test remove wipe code.
with session, session.client as client: with session:
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest()] * 3 + [messages.Success] [messages.ButtonRequest()] * 3 + [messages.Success]
) )
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
device.change_wipe_code(session, remove=True) device.change_wipe_code(session, remove=True)
# session.init_device() # session.init_device()
@ -107,9 +107,11 @@ def test_set_remove_wipe_code(session: Session):
def test_set_wipe_code_mismatch(session: Session): def test_set_wipe_code_mismatch(session: Session):
with session, session.client as client, pytest.raises(TrezorFailure): with session, pytest.raises(TrezorFailure):
IF = InputFlowNewCodeMismatch(client, WIPE_CODE4, WIPE_CODE6, what="wipe_code") IF = InputFlowNewCodeMismatch(
client.set_input_flow(IF.get()) session.client, WIPE_CODE4, WIPE_CODE6, what="wipe_code"
)
session.set_input_flow(IF.get())
device.change_wipe_code(session) device.change_wipe_code(session)
@ -122,15 +124,15 @@ def test_set_wipe_code_mismatch(session: Session):
def test_set_wipe_code_to_pin(session: Session): def test_set_wipe_code_to_pin(session: Session):
_ensure_unlocked(session, PIN4) _ensure_unlocked(session, PIN4)
with session, session.client as client: with session:
if client.layout_type is LayoutType.Caesar: if session.client.layout_type is LayoutType.Caesar:
br_count = 8 br_count = 8
else: else:
br_count = 7 br_count = 7
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success], [messages.ButtonRequest()] * br_count + [messages.Success],
) )
client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4]) session.client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4])
device.change_wipe_code(session) device.change_wipe_code(session)
# session.init_device() # session.init_device()
@ -140,20 +142,20 @@ def test_set_wipe_code_to_pin(session: Session):
def test_set_pin_to_wipe_code(session: Session): def test_set_pin_to_wipe_code(session: Session):
# Set wipe code. # Set wipe code.
with session, session.client as client: with session:
if client.layout_type is LayoutType.Caesar: if session.client.layout_type is LayoutType.Caesar:
br_count = 5 br_count = 5
else: else:
br_count = 4 br_count = 4
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success] [messages.ButtonRequest()] * br_count + [messages.Success]
) )
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
device.change_wipe_code(session) device.change_wipe_code(session)
# Try to set the PIN to the current wipe code value. # Try to set the PIN to the current wipe code value.
with session, session.client as client, pytest.raises(TrezorFailure): with session, pytest.raises(TrezorFailure):
if client.layout_type is LayoutType.Caesar: if session.client.layout_type is LayoutType.Caesar:
br_count = 6 br_count = 6
else: else:
br_count = 4 br_count = 4
@ -161,5 +163,5 @@ def test_set_pin_to_wipe_code(session: Session):
[messages.ButtonRequest()] * br_count [messages.ButtonRequest()] * br_count
+ [messages.Failure(code=messages.FailureType.PinInvalid)] + [messages.Failure(code=messages.FailureType.PinInvalid)]
) )
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
device.change_pin(session) device.change_pin(session)

View File

@ -33,8 +33,8 @@ pytestmark = pytest.mark.models("legacy")
def _check_pin(session: Session, pin): def _check_pin(session: Session, pin):
session.lock() session.lock()
with session, session.client as client: with session:
client.use_pin_sequence([pin]) session.client.use_pin_sequence([pin])
session.set_expected_responses([messages.PinMatrixRequest, messages.Address]) session.set_expected_responses([messages.PinMatrixRequest, messages.Address])
get_test_address(session) get_test_address(session)
@ -53,8 +53,8 @@ def test_set_pin(session: Session):
_check_no_pin(session) _check_no_pin(session)
# Let's set new PIN # Let's set new PIN
with session, session.client as client: with session:
client.use_pin_sequence([PIN_MAX, PIN_MAX]) session.client.use_pin_sequence([PIN_MAX, PIN_MAX])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
@ -78,8 +78,8 @@ def test_change_pin(session: Session):
_check_pin(session, PIN4) _check_pin(session, PIN4)
# Let's change PIN # Let's change PIN
with session, session.client as client: with session:
client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) session.client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
@ -104,8 +104,8 @@ def test_remove_pin(session: Session):
_check_pin(session, PIN4) _check_pin(session, PIN4)
# Let's remove PIN # Let's remove PIN
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
@ -126,11 +126,9 @@ def test_set_mismatch(session: Session):
_check_no_pin(session) _check_no_pin(session)
# Let's set new PIN # Let's set new PIN
with session, session.client as client, pytest.raises( with session, pytest.raises(TrezorFailure, match="PIN mismatch"):
TrezorFailure, match="PIN mismatch"
):
# use different PINs for first and second attempt. This will fail. # use different PINs for first and second attempt. This will fail.
client.use_pin_sequence([PIN4, PIN_MAX]) session.client.use_pin_sequence([PIN4, PIN_MAX])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
@ -152,10 +150,8 @@ def test_change_mismatch(session: Session):
assert session.features.pin_protection is True assert session.features.pin_protection is True
# Let's set new PIN # Let's set new PIN
with session, session.client as client, pytest.raises( with session, pytest.raises(TrezorFailure, match="PIN mismatch"):
TrezorFailure, match="PIN mismatch" session.client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"])
):
client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),

View File

@ -37,9 +37,9 @@ pytestmark = pytest.mark.models("core")
def _check_pin(session: Session, pin: str): def _check_pin(session: Session, pin: str):
with session, session.client as client: with session:
client.ui.__init__(client.debug) session.client.ui.__init__(session.client.debug)
client.use_pin_sequence([pin, pin, pin, pin, pin, pin]) session.client.use_pin_sequence([pin, pin, pin, pin, pin, pin])
session.lock() session.lock()
assert session.features.pin_protection is True assert session.features.pin_protection is True
assert session.features.unlocked is False assert session.features.unlocked is False
@ -63,12 +63,12 @@ def test_set_pin(session: Session):
_check_no_pin(session) _check_no_pin(session)
# Let's set new PIN # Let's set new PIN
with session, session.client as client: with session:
if client.layout_type is LayoutType.Caesar: if session.client.layout_type is LayoutType.Caesar:
br_count = 6 br_count = 6
else: else:
br_count = 4 br_count = 4
client.use_pin_sequence([PIN_MAX, PIN_MAX]) session.client.use_pin_sequence([PIN_MAX, PIN_MAX])
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest] * br_count + [messages.Success] [messages.ButtonRequest] * br_count + [messages.Success]
) )
@ -86,9 +86,9 @@ def test_change_pin(session: Session):
_check_pin(session, PIN4) _check_pin(session, PIN4)
# Let's change PIN # Let's change PIN
with session, session.client as client: with session:
client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) session.client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX])
if client.layout_type is LayoutType.Caesar: if session.client.layout_type is LayoutType.Caesar:
br_count = 6 br_count = 6
else: else:
br_count = 5 br_count = 5
@ -113,8 +113,8 @@ def test_remove_pin(session: Session):
_check_pin(session, PIN4) _check_pin(session, PIN4)
# Let's remove PIN # Let's remove PIN
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses( session.set_expected_responses(
[messages.ButtonRequest] * 3 + [messages.Success] [messages.ButtonRequest] * 3 + [messages.Success]
) )
@ -132,9 +132,9 @@ def test_set_failed(session: Session):
# Check that there's no PIN protection # Check that there's no PIN protection
_check_no_pin(session) _check_no_pin(session)
with session, session.client as client, pytest.raises(TrezorFailure): with session, pytest.raises(TrezorFailure):
IF = InputFlowNewCodeMismatch(client, PIN4, PIN60, what="pin") IF = InputFlowNewCodeMismatch(session.client, PIN4, PIN60, what="pin")
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.change_pin(session) device.change_pin(session)
@ -151,9 +151,9 @@ def test_change_failed(session: Session):
# Check current PIN value # Check current PIN value
_check_pin(session, PIN4) _check_pin(session, PIN4)
with session, session.client as client, pytest.raises(Cancelled): with session, pytest.raises(Cancelled):
IF = InputFlowCodeChangeFail(session, PIN4, "457891", "381847") IF = InputFlowCodeChangeFail(session, PIN4, "457891", "381847")
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.change_pin(session) device.change_pin(session)
@ -170,9 +170,9 @@ def test_change_invalid_current(session: Session):
# Check current PIN value # Check current PIN value
_check_pin(session, PIN4) _check_pin(session, PIN4)
with session, session.client as client, pytest.raises(TrezorFailure): with session, pytest.raises(TrezorFailure):
IF = InputFlowWrongPIN(client, PIN60) IF = InputFlowWrongPIN(session.client, PIN60)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.change_pin(session) device.change_pin(session)
@ -200,7 +200,7 @@ def test_pin_menu_cancel_setup(session: Session):
# tap to confirm # tap to confirm
debug.click(debug.screen_buttons.tap_to_confirm()) debug.click(debug.screen_buttons.tap_to_confirm())
with session, session.client as client, pytest.raises(Cancelled): with session, pytest.raises(Cancelled):
client.set_input_flow(cancel_pin_setup_input_flow) session.set_input_flow(cancel_pin_setup_input_flow)
session.call(messages.ChangePin()) session.call(messages.ChangePin())
_check_no_pin(session) _check_no_pin(session)

View File

@ -45,9 +45,8 @@ def test_wipe_device(client: Client):
@pytest.mark.setup_client(pin=PIN4) @pytest.mark.setup_client(pin=PIN4)
def test_autolock_not_retained(session: Session): def test_autolock_not_retained(session: Session):
client = session.client client = session.client
with client: client.use_pin_sequence([PIN4])
client.use_pin_sequence([PIN4]) device.apply_settings(session, auto_lock_delay_ms=10_000)
device.apply_settings(session, auto_lock_delay_ms=10_000)
assert session.features.auto_lock_delay_ms == 10_000 assert session.features.auto_lock_delay_ms == 10_000
@ -57,21 +56,20 @@ def test_autolock_not_retained(session: Session):
assert client.features.auto_lock_delay_ms > 10_000 assert client.features.auto_lock_delay_ms > 10_000
with client: client.use_pin_sequence([PIN4, PIN4])
client.use_pin_sequence([PIN4, PIN4]) device.setup(
device.setup( session,
session, skip_backup=True,
skip_backup=True, pin_protection=True,
pin_protection=True, passphrase_protection=False,
passphrase_protection=False, entropy_check_count=0,
entropy_check_count=0, backup_type=messages.BackupType.Bip39,
backup_type=messages.BackupType.Bip39, )
)
time.sleep(10.5) time.sleep(10.5)
session = client.get_session() session = client.get_session()
with session, client: with session:
# after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked # after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked
session.set_expected_responses([messages.Address]) session.set_expected_responses([messages.Address])
get_test_address(session) get_test_address(session)

View File

@ -39,8 +39,8 @@ def test_no_protection(session: Session):
def test_correct_pin(session: Session): def test_correct_pin(session: Session):
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
# Expected responses differ between T1 and TT # Expected responses differ between T1 and TT
is_t1 = session.model is models.T1B1 is_t1 = session.model is models.T1B1
session.set_expected_responses( session.set_expected_responses(
@ -65,9 +65,9 @@ def test_incorrect_pin_t1(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_incorrect_pin_t2(session: Session): def test_incorrect_pin_t2(session: Session):
with session, session.client as client: with session:
# After first incorrect attempt, TT will not raise an error, but instead ask for another attempt # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt
client.use_pin_sequence([BAD_PIN, PIN4]) session.client.use_pin_sequence([BAD_PIN, PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry),
@ -82,15 +82,15 @@ def test_incorrect_pin_t2(session: Session):
def test_exponential_backoff_t1(session: Session): def test_exponential_backoff_t1(session: Session):
for attempt in range(3): for attempt in range(3):
start = time.time() start = time.time()
with session, session.client as client, pytest.raises(PinException): with session, pytest.raises(PinException):
client.use_pin_sequence([BAD_PIN]) session.client.use_pin_sequence([BAD_PIN])
get_test_address(session) get_test_address(session)
check_pin_backoff_time(attempt, start) check_pin_backoff_time(attempt, start)
@pytest.mark.models("core") @pytest.mark.models("core")
def test_exponential_backoff_t2(session: Session): def test_exponential_backoff_t2(session: Session):
with session.client as client: with session:
IF = InputFlowPINBackoff(client, BAD_PIN, PIN4) IF = InputFlowPINBackoff(session.client, BAD_PIN, PIN4)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
get_test_address(session) get_test_address(session)

View File

@ -56,12 +56,12 @@ def _assert_protection(
session: Session, pin: bool = True, passphrase: bool = True session: Session, pin: bool = True, passphrase: bool = True
) -> Session: ) -> Session:
"""Make sure PIN and passphrase protection have expected values""" """Make sure PIN and passphrase protection have expected values"""
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.ensure_unlocked() session.ensure_unlocked()
client.refresh_features() session.client.refresh_features()
assert client.features.pin_protection is pin assert session.client.features.pin_protection is pin
assert client.features.passphrase_protection is passphrase assert session.client.features.passphrase_protection is passphrase
session.lock() session.lock()
# session.end() # session.end()
if session.protocol_version == ProtocolVersion.PROTOCOL_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
@ -70,8 +70,8 @@ def _assert_protection(
def test_initialize(session: Session): def test_initialize(session: Session):
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.ensure_unlocked() session.ensure_unlocked()
session = _assert_protection(session) session = _assert_protection(session)
with session: with session:
@ -86,8 +86,8 @@ def test_passphrase_reporting(session: Session, passphrase):
"""On TT, passphrase_protection is a private setting, so a locked device should """On TT, passphrase_protection is a private setting, so a locked device should
report passphrase_protection=None. report passphrase_protection=None.
""" """
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
device.apply_settings(session, use_passphrase=passphrase) device.apply_settings(session, use_passphrase=passphrase)
session.lock() session.lock()
@ -108,8 +108,8 @@ def test_passphrase_reporting(session: Session, passphrase):
def test_apply_settings(session: Session): def test_apply_settings(session: Session):
session = _assert_protection(session) session = _assert_protection(session)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
_pin_request(session), _pin_request(session),
@ -124,8 +124,8 @@ def test_apply_settings(session: Session):
@pytest.mark.models("legacy") @pytest.mark.models("legacy")
def test_change_pin_t1(session: Session): def test_change_pin_t1(session: Session):
session = _assert_protection(session) session = _assert_protection(session)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4, PIN4, PIN4]) session.client.use_pin_sequence([PIN4, PIN4, PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
messages.ButtonRequest, messages.ButtonRequest,
@ -141,8 +141,8 @@ def test_change_pin_t1(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_change_pin_t2(session: Session): def test_change_pin_t2(session: Session):
session = _assert_protection(session) session = _assert_protection(session)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) session.client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
_pin_request(session), _pin_request(session),
@ -172,8 +172,8 @@ def test_ping(session: Session):
def test_get_entropy(session: Session): def test_get_entropy(session: Session):
session = _assert_protection(session) session = _assert_protection(session)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
_pin_request(session), _pin_request(session),
@ -187,8 +187,8 @@ def test_get_entropy(session: Session):
def test_get_public_key(session: Session): def test_get_public_key(session: Session):
session = _assert_protection(session) session = _assert_protection(session)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)] expected_responses = [_pin_request(session)]
if session.protocol_version == ProtocolVersion.PROTOCOL_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
@ -202,8 +202,8 @@ def test_get_public_key(session: Session):
def test_get_address(session: Session): def test_get_address(session: Session):
session = _assert_protection(session) session = _assert_protection(session)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)] expected_responses = [_pin_request(session)]
if session.protocol_version == ProtocolVersion.PROTOCOL_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest) expected_responses.append(messages.PassphraseRequest)
@ -221,8 +221,8 @@ def test_wipe_device(session: Session):
device.wipe(session) device.wipe(session)
client = session.client.get_new_client() client = session.client.get_new_client()
session = client.get_seedless_session() session = client.get_seedless_session()
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses([messages.Features]) session.set_expected_responses([messages.Features])
session.call(messages.GetFeatures()) session.call(messages.GetFeatures())
@ -301,8 +301,8 @@ def test_recovery_device(session: Session):
def test_sign_message(session: Session): def test_sign_message(session: Session):
session = _assert_protection(session) session = _assert_protection(session)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)] expected_responses = [_pin_request(session)]
@ -350,8 +350,8 @@ def test_verify_message_t1(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_verify_message_t2(session: Session): def test_verify_message_t2(session: Session):
session = _assert_protection(session) session = _assert_protection(session)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses( session.set_expected_responses(
[ [
_pin_request(session), _pin_request(session),
@ -389,8 +389,8 @@ def test_signtx(session: Session):
) )
session = _assert_protection(session) session = _assert_protection(session)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)] expected_responses = [_pin_request(session)]
if session.protocol_version == ProtocolVersion.PROTOCOL_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest) expected_responses.append(messages.PassphraseRequest)
@ -430,8 +430,8 @@ def test_unlocked(session: Session):
session = _assert_protection(session, passphrase=False) session = _assert_protection(session, passphrase=False)
with session, session.client as client: with session:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses([_pin_request(session), messages.Address]) session.set_expected_responses([_pin_request(session), messages.Address])
get_test_address(session) get_test_address(session)

View File

@ -39,9 +39,9 @@ def test_repeated_backup(session: Session):
# initial device backup # initial device backup
mnemonics = [] mnemonics = []
with session, session.client as client: with session:
IF = InputFlowSlip39BasicBackup(client, False) IF = InputFlowSlip39BasicBackup(session.client, False)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
mnemonics = IF.mnemonics mnemonics = IF.mnemonics
@ -56,11 +56,11 @@ def test_repeated_backup(session: Session):
device.backup(session) device.backup(session)
# unlock repeated backup by entering 3 of the 5 shares we have got # unlock repeated backup by entering 3 of the 5 shares we have got
with session, session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, mnemonics[:3], unlock_repeated_backup=True session.client, mnemonics[:3], unlock_repeated_backup=True
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup)
assert ( assert (
session.features.backup_availability session.features.backup_availability
@ -69,9 +69,9 @@ def test_repeated_backup(session: Session):
assert session.features.recovery_status == messages.RecoveryStatus.Backup assert session.features.recovery_status == messages.RecoveryStatus.Backup
# we can now perform another backup # we can now perform another backup
with session, session.client as client: with session:
IF = InputFlowSlip39BasicBackup(client, False, repeated=True) IF = InputFlowSlip39BasicBackup(session.client, False, repeated=True)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
# the backup feature is locked again... # the backup feature is locked again...
@ -92,11 +92,11 @@ def test_repeated_backup_upgrade_single(session: Session):
assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable
# unlock repeated backup by entering the single share # unlock repeated backup by entering the single share
with session, session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True session.client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup)
assert ( assert (
session.features.backup_availability session.features.backup_availability
@ -105,9 +105,9 @@ def test_repeated_backup_upgrade_single(session: Session):
assert session.features.recovery_status == messages.RecoveryStatus.Backup assert session.features.recovery_status == messages.RecoveryStatus.Backup
# we can now perform another backup # we can now perform another backup
with session, session.client as client: with session:
IF = InputFlowSlip39BasicBackup(client, False, repeated=True) IF = InputFlowSlip39BasicBackup(session.client, False, repeated=True)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
# backup type was upgraded: # backup type was upgraded:
@ -128,9 +128,9 @@ def test_repeated_backup_cancel(session: Session):
# initial device backup # initial device backup
mnemonics = [] mnemonics = []
with session, session.client as client: with session:
IF = InputFlowSlip39BasicBackup(client, False) IF = InputFlowSlip39BasicBackup(session.client, False)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
mnemonics = IF.mnemonics mnemonics = IF.mnemonics
@ -145,11 +145,11 @@ def test_repeated_backup_cancel(session: Session):
device.backup(session) device.backup(session)
# unlock repeated backup by entering 3 of the 5 shares we have got # unlock repeated backup by entering 3 of the 5 shares we have got
with session, session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, mnemonics[:3], unlock_repeated_backup=True session.client, mnemonics[:3], unlock_repeated_backup=True
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup)
assert ( assert (
session.features.backup_availability session.features.backup_availability
@ -183,9 +183,9 @@ def test_repeated_backup_send_disallowed_message(session: Session):
# initial device backup # initial device backup
mnemonics = [] mnemonics = []
with session, session.client as client: with session:
IF = InputFlowSlip39BasicBackup(client, False) IF = InputFlowSlip39BasicBackup(session.client, False)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(session) device.backup(session)
mnemonics = IF.mnemonics mnemonics = IF.mnemonics
@ -200,11 +200,11 @@ def test_repeated_backup_send_disallowed_message(session: Session):
device.backup(session) device.backup(session)
# unlock repeated backup by entering 3 of the 5 shares we have got # unlock repeated backup by entering 3 of the 5 shares we have got
with session, session.client as client: with session:
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, mnemonics[:3], unlock_repeated_backup=True session.client, mnemonics[:3], unlock_repeated_backup=True
) )
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup) device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup)
assert ( assert (
session.features.backup_availability session.features.backup_availability

View File

@ -45,8 +45,8 @@ def test_sd_no_format(session: Session):
yield # format SD card yield # format SD card
debug.press_no() debug.press_no()
with session, session.client as client, pytest.raises(TrezorFailure) as e: with session, pytest.raises(TrezorFailure) as e:
client.set_input_flow(input_flow) session.set_input_flow(input_flow)
device.sd_protect(session, Op.ENABLE) device.sd_protect(session, Op.ENABLE)
assert e.value.code == messages.FailureType.ProcessError assert e.value.code == messages.FailureType.ProcessError
@ -76,9 +76,9 @@ def test_sd_protect_unlock(session: Session):
assert TR.sd_card__enabled in layout().text_content() assert TR.sd_card__enabled in layout().text_content()
debug.press_yes() debug.press_yes()
with session, session.client as client: with session:
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow_enable_sd_protect) session.set_input_flow(input_flow_enable_sd_protect)
device.sd_protect(session, Op.ENABLE) device.sd_protect(session, Op.ENABLE)
def input_flow_change_pin(): def input_flow_change_pin():
@ -102,9 +102,9 @@ def test_sd_protect_unlock(session: Session):
assert TR.pin__changed in layout().text_content() assert TR.pin__changed in layout().text_content()
debug.press_yes() debug.press_yes()
with session, session.client as client: with session:
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow_change_pin) session.set_input_flow(input_flow_change_pin)
device.change_pin(session) device.change_pin(session)
debug.erase_sd_card(format=False) debug.erase_sd_card(format=False)
@ -125,9 +125,9 @@ def test_sd_protect_unlock(session: Session):
) )
debug.press_no() # close debug.press_no() # close
with session, session.client as client, pytest.raises(TrezorFailure) as e: with session, pytest.raises(TrezorFailure) as e:
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow_change_pin_format) session.set_input_flow(input_flow_change_pin_format)
device.change_pin(session) device.change_pin(session)
assert e.value.code == messages.FailureType.ProcessError assert e.value.code == messages.FailureType.ProcessError

View File

@ -41,7 +41,7 @@ def test_clear_session(client: Client):
cached_responses = [messages.PublicKey] cached_responses = [messages.PublicKey]
session = client.get_session() session = client.get_session()
session.lock() session.lock()
with client, session: with session:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
session.set_expected_responses(init_responses + cached_responses) session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB assert get_public_node(session, ADDRESS_N).xpub == XPUB
@ -57,7 +57,7 @@ def test_clear_session(client: Client):
session = client.get_session() session = client.get_session()
# session cache is cleared # session cache is cleared
with client, session: with session:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
session.set_expected_responses(init_responses + cached_responses) session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB assert get_public_node(session, ADDRESS_N).xpub == XPUB
@ -76,7 +76,7 @@ def test_end_session(client: Client):
assert session.id is not None assert session.id is not None
# get_address will succeed # get_address will succeed
with session: with session as session:
session.set_expected_responses([messages.Address]) session.set_expected_responses([messages.Address])
get_test_address(session) get_test_address(session)
@ -135,7 +135,7 @@ def test_end_session_only_current(client: Client):
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
def test_session_recycling(client: Client): def test_session_recycling(client: Client):
session = client.get_session(passphrase="TREZOR") session = client.get_session(passphrase="TREZOR")
with client, session: with session:
session.set_expected_responses( session.set_expected_responses(
[ [
messages.PassphraseRequest, messages.PassphraseRequest,
@ -152,7 +152,7 @@ def test_session_recycling(client: Client):
session_x.end() session_x.end()
# it should still be possible to resume the original session # it should still be possible to resume the original session
with client, session: with session:
# passphrase should still be cached # passphrase should still be cached
session.set_expected_responses([messages.Address] * 3) session.set_expected_responses([messages.Address] * 3)
client.resume_session(session) client.resume_session(session)

View File

@ -396,7 +396,7 @@ def test_passphrase_length(client: Client):
def test_hide_passphrase_from_host(client: Client): def test_hide_passphrase_from_host(client: Client):
# Without safety checks, turning it on fails # Without safety checks, turning it on fails
session = client.get_seedless_session() session = client.get_seedless_session()
with pytest.raises(TrezorFailure, match="Safety checks are strict"), client: with pytest.raises(TrezorFailure, match="Safety checks are strict"):
device.apply_settings(session, hide_passphrase_from_host=True) device.apply_settings(session, hide_passphrase_from_host=True)
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
@ -406,7 +406,7 @@ def test_hide_passphrase_from_host(client: Client):
passphrase = "abc" passphrase = "abc"
session = client.get_session(passphrase=passphrase) session = client.get_session(passphrase=passphrase)
with client, session: with session:
def input_flow(): def input_flow():
yield yield
@ -421,8 +421,8 @@ def test_hide_passphrase_from_host(client: Client):
else: else:
raise KeyError raise KeyError
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow) session.set_input_flow(input_flow)
session.set_expected_responses( session.set_expected_responses(
[ [
messages.PassphraseRequest, messages.PassphraseRequest,
@ -440,7 +440,7 @@ def test_hide_passphrase_from_host(client: Client):
# Starting new session, otherwise the passphrase would be cached # Starting new session, otherwise the passphrase would be cached
session = client.get_session(passphrase=passphrase) session = client.get_session(passphrase=passphrase)
with client, session: with session:
def input_flow(): def input_flow():
yield yield
@ -455,8 +455,8 @@ def test_hide_passphrase_from_host(client: Client):
assert passphrase in client.debug.read_layout().text_content() assert passphrase in client.debug.read_layout().text_content()
client.debug.press_yes() client.debug.press_yes()
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow) session.set_input_flow(input_flow)
session.set_expected_responses( session.set_expected_responses(
[ [
messages.PassphraseRequest, messages.PassphraseRequest,

View File

@ -44,9 +44,9 @@ def test_tezos_get_address(session: Session, path: str, expected_address: str):
def test_tezos_get_address_chunkify_details( def test_tezos_get_address_chunkify_details(
session: Session, path: str, expected_address: str session: Session, path: str, expected_address: str
): ):
with session.client as client: with session:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
address = get_address( address = get_address(
session, parse_path(path), show_display=True, chunkify=True session, parse_path(path), show_display=True, chunkify=True
) )

View File

@ -31,9 +31,9 @@ RK_CAPACITY = 100
@pytest.mark.altcoin @pytest.mark.altcoin
@pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_add_remove(session: Session): def test_add_remove(session: Session):
with session, session.client as client: with session:
IF = InputFlowFidoConfirm(client) IF = InputFlowFidoConfirm(session.client)
client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
# Remove index 0 should fail. # Remove index 0 should fail.
with pytest.raises(TrezorFailure): with pytest.raises(TrezorFailure):

View File

@ -50,16 +50,18 @@ class InputFlowBase:
# There could be one common input flow for all models # There could be one common input flow for all models
if hasattr(self, "input_flow_common"): if hasattr(self, "input_flow_common"):
return getattr(self, "input_flow_common") flow = getattr(self, "input_flow_common")
elif self.client.layout_type is LayoutType.Bolt: elif self.client.layout_type is LayoutType.Bolt:
return self.input_flow_bolt flow = self.input_flow_bolt
elif self.client.layout_type is LayoutType.Caesar: elif self.client.layout_type is LayoutType.Caesar:
return self.input_flow_caesar flow = self.input_flow_caesar
elif self.client.layout_type is LayoutType.Delizia: elif self.client.layout_type is LayoutType.Delizia:
return self.input_flow_delizia flow = self.input_flow_delizia
else: else:
raise ValueError("Unknown model") raise ValueError("Unknown model")
return flow
def input_flow_bolt(self) -> BRGeneratorType: def input_flow_bolt(self) -> BRGeneratorType:
"""Special for TT""" """Special for TT"""
raise NotImplementedError raise NotImplementedError
@ -371,7 +373,7 @@ class InputFlowSignMessageInfo(InputFlowBase):
self.debug.click(self.client.debug.screen_buttons.vertical_menu_items()[1]) self.debug.click(self.client.debug.screen_buttons.vertical_menu_items()[1])
# address mismatch? yes! # address mismatch? yes!
self.debug.swipe_up() self.debug.swipe_up()
yield yield # ?
class InputFlowShowAddressQRCode(InputFlowBase): class InputFlowShowAddressQRCode(InputFlowBase):

View File

@ -11,33 +11,37 @@ WIPE_CODE = "9876"
def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None: def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None:
device.wipe(client.get_seedless_session()) session = client.get_seedless_session()
device.wipe(session)
client = client.get_new_client() client = client.get_new_client()
session = client.get_seedless_session()
debuglink.load_device( debuglink.load_device(
client.get_seedless_session(), session,
MNEMONIC12, MNEMONIC12,
pin, pin,
passphrase_protection=False, passphrase_protection=False,
label="WIPECODE", label="WIPECODE",
) )
with client: with session:
client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE])
device.change_wipe_code(client.get_seedless_session()) device.change_wipe_code(client.get_seedless_session())
def setup_device_core(client: Client, pin: str, wipe_code: str) -> None: def setup_device_core(client: Client, pin: str, wipe_code: str) -> None:
device.wipe(client.get_seedless_session()) session = client.get_seedless_session()
device.wipe(session)
client = client.get_new_client() client = client.get_new_client()
session = client.get_seedless_session()
debuglink.load_device( debuglink.load_device(
client.get_seedless_session(), session,
MNEMONIC12, MNEMONIC12,
pin, pin,
passphrase_protection=False, passphrase_protection=False,
label="WIPECODE", label="WIPECODE",
) )
with client: with session:
client.use_pin_sequence([pin, wipe_code, wipe_code]) client.use_pin_sequence([pin, wipe_code, wipe_code])
device.change_wipe_code(client.get_seedless_session()) device.change_wipe_code(client.get_seedless_session())

View File

@ -96,7 +96,7 @@ def test_upgrade_load_pin(gen: str, tag: str) -> None:
assert client.features.initialized assert client.features.initialized
assert client.features.label == LABEL assert client.features.label == LABEL
session = client.get_session() session = client.get_session()
with client, session: with session:
client.use_pin_sequence([PIN]) client.use_pin_sequence([PIN])
assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS
@ -395,10 +395,11 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
# Create a backup of the encrypted master secret. # Create a backup of the encrypted master secret.
assert emu.client.features.backup_availability == BackupAvailability.Required assert emu.client.features.backup_availability == BackupAvailability.Required
with emu.client: session = emu.client.get_session()
with session:
IF = InputFlowSlip39BasicBackup(emu.client, False) IF = InputFlowSlip39BasicBackup(emu.client, False)
emu.client.set_input_flow(IF.get()) session.set_input_flow(IF.get())
device.backup(emu.client.get_session()) device.backup(session)
assert ( assert (
emu.client.features.backup_availability == BackupAvailability.NotAvailable emu.client.features.backup_availability == BackupAvailability.NotAvailable
) )