1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-15 23:56:12 +00:00

refactor(tests): move set_input_flow to SessionDebugWrapper context manager

[no changelog]
This commit is contained in:
Martin Milata 2025-03-06 00:35:51 +01:00 committed by M1nd3r
parent 327fe1f98b
commit 46ac326354
72 changed files with 632 additions and 668 deletions

View File

@ -789,10 +789,10 @@ class DebugUI:
def __init__(self, debuglink: DebugLink) -> None:
self.debuglink = debuglink
self.pins: t.Iterator[str] | None = None
self.clear()
def clear(self) -> None:
self.pins: t.Iterator[str] | None = None
self.passphrase = None
self.input_flow: t.Union[
t.Generator[None, messages.ButtonRequest, None], object, None
@ -964,12 +964,10 @@ class SessionDebugWrapper(Session):
return self.client.protocol_version
def _write(self, msg: t.Any) -> None:
print("writing message:", msg.__class__.__name__)
self._session._write(self._filter_message(msg))
def _read(self) -> t.Any:
resp = self._filter_message(self._session._read())
print("reading message:", resp.__class__.__name__)
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
@ -1067,6 +1065,7 @@ class SessionDebugWrapper(Session):
Clears all debugging state that might have been modified by a testcase.
"""
self.client.ui.clear() # type: ignore [Cannot access attribute]
self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None
@ -1103,7 +1102,6 @@ class SessionDebugWrapper(Session):
# If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses)
elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in
# traceback where it is stuck.
@ -1163,6 +1161,45 @@ class SessionDebugWrapper(Session):
output.append("")
return output
def set_input_flow(
self,
input_flow: InputFlowType | t.Callable[[], InputFlowType],
) -> None:
"""Configure a sequence of input events for the current with-block.
The `input_flow` must be a generator function. A `yield` statement in the
input flow function waits for a ButtonRequest from the device, and returns
its code.
Example usage:
>>> def input_flow():
>>> # wait for first button prompt
>>> code = yield
>>> assert code == ButtonRequestType.Other
>>> # press No
>>> client.debug.press_no()
>>>
>>> # wait for second button prompt
>>> yield
>>> # press Yes
>>> client.debug.press_yes()
>>>
>>> with session:
>>> session.set_input_flow(input_flow)
>>> some_call(session)
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
if callable(input_flow):
input_flow = input_flow()
if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function")
self.client.ui.input_flow = input_flow # type: ignore [Cannot access attribute]
next(input_flow) # start the generator
class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses
@ -1204,7 +1241,6 @@ class TrezorClientDebugLink(TrezorClient):
self.transport = transport
self.ui: DebugUI = DebugUI(self.debug)
self.reset_debug_features()
self._seedless_session = self.get_seedless_session(new_session=True)
self.sync_responses()
@ -1229,15 +1265,6 @@ class TrezorClientDebugLink(TrezorClient):
new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
return new_client
def reset_debug_features(self) -> None:
"""
Prepare the debugging client for a new testcase.
Clears all debugging state that might have been modified by a testcase.
"""
self.ui: DebugUI = DebugUI(self.debug)
self.in_with_statement = False
def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
@ -1366,43 +1393,6 @@ class TrezorClientDebugLink(TrezorClient):
else:
return SessionDebugWrapper(super().resume_session(session))
def set_input_flow(
self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
) -> None:
"""Configure a sequence of input events for the current with-block.
The `input_flow` must be a generator function. A `yield` statement in the
input flow function waits for a ButtonRequest from the device, and returns
its code.
Example usage:
>>> def input_flow():
>>> # wait for first button prompt
>>> code = yield
>>> assert code == ButtonRequestType.Other
>>> # press No
>>> client.debug.press_no()
>>>
>>> # wait for second button prompt
>>> yield
>>> # press Yes
>>> client.debug.press_yes()
>>>
>>> with client:
>>> client.set_input_flow(input_flow)
>>> some_call(client)
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
if callable(input_flow):
input_flow = input_flow()
if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function")
self.ui.input_flow = input_flow
next(input_flow) # start the generator
def watch_layout(self, watch: bool = True) -> None:
"""Enable or disable watching layout changes.
@ -1416,29 +1406,6 @@ class TrezorClientDebugLink(TrezorClient):
# - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
self.debug.watch_layout(watch)
def __enter__(self) -> "TrezorClientDebugLink":
# For usage in with/expected_responses
if self.in_with_statement:
raise RuntimeError("Do not nest!")
self.in_with_statement = True
return self
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# grab a copy of the inputflow generator to raise an exception through it
if isinstance(self.ui, DebugUI):
input_flow = self.ui.input_flow
else:
input_flow = None
self.reset_debug_features()
if exc_type is not None and isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in
# traceback where it is stuck.
input_flow.throw(exc_type, value, traceback)
def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
"""Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts.
@ -1450,25 +1417,6 @@ class TrezorClientDebugLink(TrezorClient):
Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
@staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output: list[str] = []
output.append("Expected responses:")
if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)")
for i in range(start_at, stop_at):
exp = expected[i]
prefix = " " if i != current else ">>> "
output.append(textwrap.indent(exp.to_string(), prefix))
if stop_at < len(expected):
omitted = len(expected) - stop_at
output.append(f" (...{omitted} following responses omitted)")
output.append("")
return output
def sync_responses(self) -> None:
"""Synchronize Trezor device receiving with caller.

View File

@ -56,7 +56,7 @@ def pin_input_flow(client: Client, old_pin: str, new_pin: str):
if __name__ == "__main__":
wirelink = get_device()
client = Client(wirelink)
client.open()
session = client.get_seedless_session()
i = 0
@ -76,10 +76,12 @@ if __name__ == "__main__":
# change PIN
new_pin = "".join(random.choices(string.digits, k=random.randint(6, 10)))
client.set_input_flow(pin_input_flow(client, last_pin, new_pin))
session.set_input_flow(pin_input_flow(client, last_pin, new_pin))
device.change_pin(client)
client.set_input_flow(None)
session.set_input_flow(None)
last_pin = new_pin
print(f"iteration {i}")
i = i + 1
wirelink.close()

View File

@ -198,7 +198,7 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
session.set_filter(messages.TxAck, None)
return msg
with session, device_handler.client:
with session:
session.set_filter(messages.TxAck, sleepy_filter)
# confirm transaction
if debug.layout_type is LayoutType.Bolt:

View File

@ -283,7 +283,7 @@ def _client_unlocked(
test_ui = request.config.getoption("ui")
_raw_client.reset_debug_features()
# _raw_client.reset_debug_features()
if isinstance(_raw_client.protocol, ProtocolV1Channel):
try:
_raw_client.sync_responses()

View File

@ -22,6 +22,10 @@ udp.SOCKET_TIMEOUT = 0.1
class NullUI:
@staticmethod
def clear(*args, **kwargs):
pass
@staticmethod
def button_request(code):
pass

View File

@ -51,9 +51,9 @@ def test_binance_get_address_chunkify_details(
):
# data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50
with session.client as client:
IF = InputFlowShowAddressQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowAddressQRCode(session.client)
session.set_input_flow(IF.get())
address = get_address(
session, parse_path(path), show_display=True, chunkify=True
)

View File

@ -32,9 +32,9 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0")
mnemonic="offer caution gift cross surge pretty orange during eye soldier popular holiday mention east eight office fashion ill parrot vault rent devote earth cousin"
)
def test_binance_get_public_key(session: Session):
with session.client as client:
IF = InputFlowShowXpubQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowXpubQRCode(session.client)
session.set_input_flow(IF.get())
sig = binance.get_public_key(session, BINANCE_PATH, show_display=True)
assert (
sig.hex()

View File

@ -65,8 +65,8 @@ def test_sign_tx(session: Session, chunkify: bool):
assert session.features.unlocked is False
commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big")
with session.client as client:
client.use_pin_sequence([PIN])
with session:
session.client.use_pin_sequence([PIN])
btc.authorize_coinjoin(
session,
coordinator="www.example.com",

View File

@ -168,9 +168,9 @@ def _address_n(purpose, coin, account, script_type):
def test_descriptors(
session: Session, coin, account, purpose, script_type, descriptors
):
with session.client as client:
IF = InputFlowShowXpubQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowXpubQRCode(session.client)
session.set_input_flow(IF.get())
address_n = _address_n(purpose, coin, account, script_type)
res = btc.get_public_node(
@ -191,10 +191,10 @@ def test_descriptors(
def test_descriptors_trezorlib(
session: Session, coin, account, purpose, script_type, descriptors
):
with session.client as client:
if client.model != models.T1B1:
IF = InputFlowShowXpubQRCode(client)
client.set_input_flow(IF.get())
with session:
if session.client.model != models.T1B1:
IF = InputFlowShowXpubQRCode(session.client)
session.set_input_flow(IF.get())
res = btc_cli._get_descriptor(
session, coin, account, purpose, script_type, show_display=True
)

View File

@ -270,10 +270,10 @@ def test_multisig(session: Session):
xpubs.append(node.xpub)
for nr in range(1, 4):
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
assert (
btc.get_address(
session,
@ -321,10 +321,10 @@ def test_multisig_missing(session: Session, show_display):
)
for multisig in (multisig1, multisig2):
with session.client as client, pytest.raises(TrezorFailure):
with pytest.raises(TrezorFailure), session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.get_address(
session,
"Bitcoin",
@ -345,10 +345,10 @@ def test_bch_multisig(session: Session):
xpubs.append(node.xpub)
for nr in range(1, 4):
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
assert (
btc.get_address(
session,
@ -406,7 +406,7 @@ def test_unknown_path(session: Session):
# disable safety checks
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
with session, session.client as client:
with session:
session.set_expected_responses(
[
messages.ButtonRequest(
@ -417,8 +417,8 @@ def test_unknown_path(session: Session):
]
)
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
# try again with a warning
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True)
@ -455,10 +455,10 @@ def test_multisig_different_paths(session: Session):
with pytest.raises(
Exception, match="Using different paths for different xpubs is not allowed"
):
with session.client as client, session:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.get_address(
session,
"Bitcoin",
@ -469,10 +469,10 @@ def test_multisig_different_paths(session: Session):
)
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.get_address(
session,
"Bitcoin",

View File

@ -74,10 +74,10 @@ def test_show_segwit(session: Session):
@pytest.mark.altcoin
def test_show_segwit_altcoin(session: Session):
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
assert (
btc.get_address(
session,

View File

@ -63,9 +63,9 @@ def test_show_t1(
yield
session.client.debug.press_yes()
with session.client as client:
with session:
# This is the only place where even T1 is using input flow
client.set_input_flow(input_flow_t1)
session.set_input_flow(input_flow_t1)
assert (
btc.get_address(
session,
@ -88,9 +88,9 @@ def test_show_tt(
script_type: messages.InputScriptType,
address: str,
):
with session.client as client:
IF = InputFlowShowAddressQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowAddressQRCode(session.client)
session.set_input_flow(IF.get())
assert (
btc.get_address(
session,
@ -109,9 +109,9 @@ def test_show_tt(
def test_show_cancel(
session: Session, path: str, script_type: messages.InputScriptType, address: str
):
with session.client as client, pytest.raises(Cancelled):
IF = InputFlowShowAddressQRCodeCancel(client)
client.set_input_flow(IF.get())
with session, pytest.raises(Cancelled):
IF = InputFlowShowAddressQRCodeCancel(session.client)
session.set_input_flow(IF.get())
btc.get_address(
session,
"Bitcoin",
@ -157,10 +157,10 @@ def test_show_multisig_3(session: Session):
for multisig in (multisig1, multisig2):
for i in [1, 2, 3]:
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
assert (
btc.get_address(
session,
@ -273,11 +273,11 @@ def test_show_multisig_xpubs(
)
for i in range(3):
with session, session.client as client:
IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i)
client.set_input_flow(IF.get())
client.debug.synchronize_at("Homescreen")
client.watch_layout()
with session:
IF = InputFlowShowMultisigXPUBs(session.client, address, xpubs, i)
session.set_input_flow(IF.get())
session.client.debug.synchronize_at("Homescreen")
session.client.watch_layout()
btc.get_address(
session,
"Bitcoin",
@ -314,10 +314,10 @@ def test_show_multisig_15(session: Session):
for multisig in [multisig1, multisig2]:
for i in range(15):
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
assert (
btc.get_address(
session,

View File

@ -119,9 +119,9 @@ def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub):
@pytest.mark.models("core")
@pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN)
def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub):
with session.client as client:
IF = InputFlowShowXpubQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowXpubQRCode(session.client)
session.set_input_flow(IF.get())
res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True)
assert res.xpub == xpub
assert bip32.serialize(res.node, xpub_magic) == xpub
@ -158,14 +158,14 @@ def test_get_public_node_show_legacy(
client.debug.press_yes() # finish the flow
yield
with client:
with session:
# test XPUB display flow (without showing QR code)
res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True)
assert res.xpub == xpub
assert bip32.serialize(res.node, xpub_magic) == xpub
# test XPUB QR code display using the input flow above
client.set_input_flow(input_flow)
session.set_input_flow(input_flow)
res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True)
assert res.xpub == xpub
assert bip32.serialize(res.node, xpub_magic) == xpub

View File

@ -475,10 +475,10 @@ def test_attack_change_input(session: Session):
)
# Transaction can be signed without the attack processor
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(
session,
"Testnet",

View File

@ -288,7 +288,7 @@ def test_external_internal(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with session, session.client as client:
with session:
session.set_expected_responses(
_responses(
session,
@ -299,8 +299,8 @@ def test_external_internal(session: Session):
)
)
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(
session,
"Bitcoin",
@ -324,7 +324,7 @@ def test_internal_external(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with session, session.client as client:
with session:
session.set_expected_responses(
_responses(
session,
@ -335,8 +335,8 @@ def test_internal_external(session: Session):
)
)
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(
session,
"Bitcoin",

View File

@ -113,10 +113,10 @@ def test_getaddress(
script_types: list[messages.InputScriptType],
):
for script_type in script_types:
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
res = btc.get_address(
session,
"Bitcoin",
@ -134,10 +134,10 @@ def test_signmessage(
session: Session, path: str, script_types: list[messages.InputScriptType]
):
for script_type in script_types:
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
sig = btc.sign_message(
session,
@ -175,10 +175,10 @@ def test_signtx(
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx(
session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx}
)
@ -202,10 +202,10 @@ def test_getaddress_multisig(
]
multisig = messages.MultisigRedeemScriptType(pubkeys=pubs, m=2)
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
address = btc.get_address(
session,
"Bitcoin",
@ -261,10 +261,10 @@ def test_signtx_multisig(session: Session, paths: list[str], address_index: list
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
sig, _ = btc.sign_tx(
session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx}
)

View File

@ -327,9 +327,9 @@ def test_signmessage_long(
message: str,
signature: str,
):
with session.client as client:
IF = InputFlowSignVerifyMessageLong(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSignVerifyMessageLong(session.client)
session.set_input_flow(IF.get())
sig = btc.sign_message(
session,
coin_name=coin_name,
@ -356,9 +356,9 @@ def test_signmessage_info(
message: str,
signature: str,
):
with session.client as client, pytest.raises(Cancelled):
IF = InputFlowSignMessageInfo(client)
client.set_input_flow(IF.get())
with session, pytest.raises(Cancelled):
IF = InputFlowSignMessageInfo(session.client)
session.set_input_flow(IF.get())
sig = btc.sign_message(
session,
coin_name=coin_name,
@ -390,13 +390,13 @@ MESSAGE_LENGTHS = (
@pytest.mark.models("core")
@pytest.mark.parametrize("message,is_long", MESSAGE_LENGTHS)
def test_signmessage_pagination(session: Session, message: str, is_long: bool):
with session.client as client:
with session:
IF = (
InputFlowSignVerifyMessageLong
if is_long
else InputFlowSignMessagePagination
)(client)
client.set_input_flow(IF.get())
)(session.client)
session.set_input_flow(IF.get())
btc.sign_message(
session,
coin_name="Bitcoin",
@ -438,7 +438,7 @@ def test_signmessage_pagination_trailing_newline(session: Session):
def test_signmessage_path_warning(session: Session):
message = "This is an example of a signed message."
with session, session.client as client:
with session:
session.set_expected_responses(
[
# expect a path warning
@ -451,8 +451,8 @@ def test_signmessage_path_warning(session: Session):
]
)
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.sign_message(
session,
coin_name="Bitcoin",

View File

@ -664,9 +664,9 @@ def test_fee_high_hardfail(session: Session):
device.apply_settings(
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
)
with session.client as client:
IF = InputFlowSignTxHighFee(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSignTxHighFee(session.client)
session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_CACHE_TESTNET
@ -1467,9 +1467,9 @@ def test_lock_time_blockheight(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with session.client as client:
IF = InputFlowLockTimeBlockHeight(client, "499999999")
client.set_input_flow(IF.get())
with session:
IF = InputFlowLockTimeBlockHeight(session.client, "499999999")
session.set_input_flow(IF.get())
btc.sign_tx(
session,
@ -1506,9 +1506,9 @@ def test_lock_time_datetime(session: Session, lock_time_str: str):
lock_time_utc = lock_time_naive.replace(tzinfo=timezone.utc)
lock_time_timestamp = int(lock_time_utc.timestamp())
with session.client as client:
IF = InputFlowLockTimeDatetime(client, lock_time_str)
client.set_input_flow(IF.get())
with session:
IF = InputFlowLockTimeDatetime(session.client, lock_time_str)
session.set_input_flow(IF.get())
btc.sign_tx(
session,
@ -1538,9 +1538,9 @@ def test_information(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with session.client as client:
IF = InputFlowSignTxInformation(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSignTxInformation(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(
session,
@ -1573,9 +1573,9 @@ def test_information_mixed(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with session.client as client:
IF = InputFlowSignTxInformationMixed(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSignTxInformationMixed(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(
session,
@ -1604,9 +1604,9 @@ def test_information_cancel(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with session.client as client, pytest.raises(Cancelled):
IF = InputFlowSignTxInformationCancel(client)
client.set_input_flow(IF.get())
with session, pytest.raises(Cancelled):
IF = InputFlowSignTxInformationCancel(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(
session,
@ -1654,9 +1654,9 @@ def test_information_replacement(session: Session):
orig_index=0,
)
with session.client as client:
IF = InputFlowSignTxInformationReplacement(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSignTxInformationReplacement(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(
session,

View File

@ -80,10 +80,10 @@ def test_invalid_path_prompt(session: Session):
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
)
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES)
@ -106,10 +106,10 @@ def test_invalid_path_pass_forkid(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS,
)
with session.client as client:
with session:
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES)

View File

@ -203,9 +203,9 @@ def test_payment_request_details(session: Session):
)
]
with session.client as client:
IF = InputFlowPaymentRequestDetails(client, outputs)
client.set_input_flow(IF.get())
with session:
IF = InputFlowPaymentRequestDetails(session.client, outputs)
session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx(
session,

View File

@ -130,11 +130,11 @@ def test_invalid_prev_hash_attack(session: Session, prev_hash):
msg.tx.inputs[0].prev_hash = prev_hash
return msg
with session, session.client as client, pytest.raises(TrezorFailure) as e:
with session, pytest.raises(TrezorFailure) as e:
session.set_filter(messages.TxAck, attack_filter)
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(session, "Bitcoin", [inp1], [out1], prev_txes=PREV_TXES)
# check that injection was performed
@ -168,9 +168,9 @@ def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash):
tx_hash = hash_tx(serialize_tx(prev_tx))
inp0.prev_hash = tx_hash
with session, session.client as client, pytest.raises(TrezorFailure) as e:
with session, pytest.raises(TrezorFailure) as e:
if session.model is not models.T1B1:
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx})
_check_error_message(prev_hash, session.model, e.value.message)

View File

@ -611,11 +611,11 @@ def test_send_multisig_3_change(session: Session):
request_finished(),
]
with session, session.client as client:
with session:
session.set_expected_responses(expected_responses)
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
signatures, _ = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
)
@ -626,11 +626,11 @@ def test_send_multisig_3_change(session: Session):
inp1.address_n[2] = H_(3)
out1.address_n[2] = H_(3)
with session, session.client as client:
with session:
session.set_expected_responses(expected_responses)
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
)
@ -703,11 +703,11 @@ def test_send_multisig_4_change(session: Session):
request_finished(),
]
with session, session.client as client:
with session:
session.set_expected_responses(expected_responses)
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
signatures, _ = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
)
@ -718,11 +718,11 @@ def test_send_multisig_4_change(session: Session):
inp1.address_n[2] = H_(3)
out1.address_n[2] = H_(3)
with session, session.client as client:
with session:
session.set_expected_responses(expected_responses)
if is_core(session):
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
)

View File

@ -40,9 +40,9 @@ def test_message_long_legacy(session: Session):
@pytest.mark.models("core")
def test_message_long_core(session: Session):
with session.client as client:
IF = InputFlowSignVerifyMessageLong(client, verify=True)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSignVerifyMessageLong(session.client, verify=True)
session.set_input_flow(IF.get())
ret = btc.verify_message(
session,
"Bitcoin",

View File

@ -95,9 +95,11 @@ def test_cardano_get_address(session: Session, chunkify: bool, parameters, resul
"cardano/get_public_key.derivations.json",
)
def test_cardano_get_public_key(session: Session, parameters, result):
with session, session.client as client:
IF = InputFlowShowXpubQRCode(client, passphrase=bool(session.passphrase))
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowXpubQRCode(
session.client, passphrase=bool(session.passphrase)
)
session.set_input_flow(IF.get())
# session.init_device(new_session=True, derive_cardano=True)
derivation_type = CardanoDerivationType.__members__[

View File

@ -63,7 +63,7 @@ def test_cardano_sign_tx(session: Session, parameters, result):
response = call_sign_tx(
session,
parameters,
input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(),
input_flow=lambda client: InputFlowConfirmAllWarnings(session.client).get(),
)
assert response == _transform_expected_result(result)
@ -122,10 +122,10 @@ def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool =
else:
device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict)
with session.client as client:
with session:
if input_flow is not None:
client.watch_layout()
client.set_input_flow(input_flow(client))
session.client.watch_layout()
session.set_input_flow(input_flow(session.client))
return cardano.sign_tx(
session=session,

View File

@ -29,9 +29,9 @@ from ...input_flows import InputFlowShowXpubQRCode
@pytest.mark.models("t2t1")
@pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_eos_get_public_key(session: Session):
with session.client as client:
IF = InputFlowShowXpubQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowXpubQRCode(session.client)
session.set_input_flow(IF.get())
public_key = get_public_key(
session, parse_path("m/44h/194h/0h/0/0"), show_display=True
)

View File

@ -123,9 +123,9 @@ def test_external_token(session: Session) -> None:
def test_external_chain_without_token(session: Session) -> None:
with session, session.client as client:
if not client.debug.legacy_debug:
client.set_input_flow(InputFlowConfirmAllWarnings(client).get())
with session:
if not session.client.debug.legacy_debug:
session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get())
# when using an external chains, unknown tokens are allowed
network = common.encode_network(chain_id=66666, slip44=60)
params = DEFAULT_ERC20_PARAMS.copy()
@ -145,9 +145,9 @@ def test_external_chain_token_ok(session: Session) -> None:
def test_external_chain_token_mismatch(session: Session) -> None:
with session, session.client as client:
if not client.debug.legacy_debug:
client.set_input_flow(InputFlowConfirmAllWarnings(client).get())
with session:
if not session.client.debug.legacy_debug:
session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get())
# when providing external defs, we explicitly allow, but not use, tokens
# from other chains
network = common.encode_network(chain_id=66666, slip44=60)

View File

@ -37,9 +37,9 @@ def test_getaddress(session: Session, parameters, result):
@pytest.mark.models("core", reason="No input flow for T1")
@parametrize_using_common_fixtures("ethereum/getaddress.json")
def test_getaddress_chunkify_details(session: Session, parameters, result):
with session.client as client:
IF = InputFlowShowAddressQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowAddressQRCode(session.client)
session.set_input_flow(IF.get())
address_n = parse_path(parameters["path"])
assert (
ethereum.get_address(session, address_n, show_display=True, chunkify=True)

View File

@ -97,10 +97,10 @@ DATA = {
@pytest.mark.models("core", skip="delizia", reason="Not yet implemented in new UI")
def test_ethereum_sign_typed_data_show_more_button(session: Session):
with session.client as client:
client.watch_layout()
IF = InputFlowEIP712ShowMore(client)
client.set_input_flow(IF.get())
with session:
session.client.watch_layout()
IF = InputFlowEIP712ShowMore(session.client)
session.set_input_flow(IF.get())
ethereum.sign_typed_data(
session,
parse_path("m/44h/60h/0h/0/0"),
@ -111,10 +111,10 @@ def test_ethereum_sign_typed_data_show_more_button(session: Session):
@pytest.mark.models("core")
def test_ethereum_sign_typed_data_cancel(session: Session):
with session.client as client, pytest.raises(exceptions.Cancelled):
client.watch_layout()
IF = InputFlowEIP712Cancel(client)
client.set_input_flow(IF.get())
with session, pytest.raises(exceptions.Cancelled):
session.client.watch_layout()
IF = InputFlowEIP712Cancel(session.client)
session.set_input_flow(IF.get())
ethereum.sign_typed_data(
session,
parse_path("m/44h/60h/0h/0/0"),

View File

@ -36,9 +36,9 @@ def test_signmessage(session: Session, parameters, result):
assert res.address == result["address"]
assert res.signature.hex() == result["sig"]
else:
with session.client as client:
IF = InputFlowSignVerifyMessageLong(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSignVerifyMessageLong(session.client)
session.set_input_flow(IF.get())
res = ethereum.sign_message(
session, parse_path(parameters["path"]), parameters["msg"]
)
@ -57,9 +57,9 @@ def test_verify(session: Session, parameters, result):
)
assert res is True
else:
with session.client as client:
IF = InputFlowSignVerifyMessageLong(client, verify=True)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSignVerifyMessageLong(session.client, verify=True)
session.set_input_flow(IF.get())
res = ethereum.verify_message(
session,
parameters["address"],

View File

@ -73,10 +73,10 @@ def _do_test_signtx(
input_flow=None,
chunkify: bool = False,
):
with session.client as client:
with session:
if input_flow:
client.watch_layout()
client.set_input_flow(input_flow)
session.client.watch_layout()
session.set_input_flow(input_flow)
sig_v, sig_r, sig_s = ethereum.sign_tx(
session,
n=parse_path(parameters["path"]),
@ -151,9 +151,9 @@ def test_signtx_go_back_from_summary(session: Session):
def test_signtx_eip1559(
session: Session, chunkify: bool, parameters: dict, result: dict
):
with session, session.client as client:
if not client.debug.legacy_debug:
client.set_input_flow(InputFlowConfirmAllWarnings(client).get())
with session:
if not session.client.debug.legacy_debug:
session.set_input_flow(InputFlowConfirmAllWarnings(session.client).get())
sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559(
session,
n=parse_path(parameters["path"]),
@ -456,15 +456,15 @@ def test_signtx_data_pagination(session: Session, flow):
data=bytes.fromhex(HEXDATA),
)
with session, session.client as client:
client.watch_layout()
client.set_input_flow(flow(client))
with session:
session.client.watch_layout()
session.set_input_flow(flow(session.client))
_sign_tx_call()
if flow is not input_flow_data_scroll_down:
with session, session.client as client, pytest.raises(exceptions.Cancelled):
client.watch_layout()
client.set_input_flow(flow(client, cancel=True))
with session, pytest.raises(exceptions.Cancelled):
session.client.watch_layout()
session.set_input_flow(flow(session.client, cancel=True))
_sign_tx_call()

View File

@ -33,8 +33,8 @@ def test_encrypt(client: Client):
client.debug.press_yes()
session = client.get_session()
with client, session:
client.set_input_flow(input_flow())
with session:
session.set_input_flow(input_flow())
misc.encrypt_keyvalue(
session,
[],

View File

@ -56,9 +56,9 @@ def test_monero_getaddress(session: Session, path: str, expected_address: bytes)
def test_monero_getaddress_chunkify_details(
session: Session, path: str, expected_address: bytes
):
with session.client as client:
IF = InputFlowShowAddressQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowAddressQRCode(session.client)
session.set_input_flow(IF.get())
address = monero.get_address(
session, parse_path(path), show_display=True, chunkify=True
)

View File

@ -51,10 +51,10 @@ def do_recover_legacy(session: Session, mnemonic: list[str]):
def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False):
with session.client as client:
client.watch_layout()
IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch)
client.set_input_flow(IF.get())
with session:
session.client.watch_layout()
IF = InputFlowBip39RecoveryDryRun(session.client, mnemonic, mismatch=mismatch)
session.set_input_flow(IF.get())
return device.recover(session, type=messages.RecoveryType.DryRun)
@ -87,10 +87,10 @@ def test_invalid_seed_t1(session: Session):
@pytest.mark.models("core")
def test_invalid_seed_core(session: Session):
with session, session.client as client:
client.watch_layout()
with session:
session.client.watch_layout()
IF = InputFlowBip39RecoveryDryRunInvalid(session)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
return device.recover(
session,

View File

@ -28,9 +28,9 @@ pytestmark = pytest.mark.models("core")
@pytest.mark.setup_client(uninitialized=True)
@pytest.mark.uninitialized_session
def test_tt_pin_passphrase(session: Session):
with session.client as client:
IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654")
client.set_input_flow(IF.get())
with session:
IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" "), pin="654")
session.set_input_flow(IF.get())
device.recover(
session,
pin_protection=True,
@ -49,9 +49,9 @@ def test_tt_pin_passphrase(session: Session):
@pytest.mark.setup_client(uninitialized=True)
@pytest.mark.uninitialized_session
def test_tt_nopin_nopassphrase(session: Session):
with session.client as client:
IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "))
client.set_input_flow(IF.get())
with session:
IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" "))
session.set_input_flow(IF.get())
device.recover(
session,
pin_protection=False,

View File

@ -48,9 +48,11 @@ VECTORS = (
def _test_secret(
session: Session, shares: list[str], secret: str, click_info: bool = False
):
with session.client as client:
IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39AdvancedRecovery(
session.client, shares, click_info=click_info
)
session.set_input_flow(IF.get())
device.recover(
session,
pin_protection=False,
@ -89,9 +91,9 @@ def test_extra_share_entered(session: Session):
@pytest.mark.setup_client(uninitialized=True)
def test_abort(session: Session):
with session.client as client:
IF = InputFlowSlip39AdvancedRecoveryAbort(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39AdvancedRecoveryAbort(session.client)
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")
session.refresh_features()
@ -100,11 +102,11 @@ def test_abort(session: Session):
@pytest.mark.setup_client(uninitialized=True)
def test_noabort(session: Session):
with session.client as client:
with session:
IF = InputFlowSlip39AdvancedRecoveryNoAbort(
client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label")
session.refresh_features()
assert session.features.initialized is True
@ -118,11 +120,11 @@ def test_same_share(session: Session):
# second share is first 4 words of first
second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4]
with session, session.client as client:
with session:
IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(
session, first_share, second_share
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")
@ -134,10 +136,10 @@ def test_group_threshold_reached(session: Session):
# second share is first 3 words of first
second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3]
with session, session.client as client:
with session:
IF = InputFlowSlip39AdvancedRecoveryThresholdReached(
session, first_share, second_share
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")

View File

@ -40,11 +40,11 @@ EXTRA_GROUP_SHARE = [
@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20, passphrase=False)
def test_2of3_dryrun(session: Session):
with session.client as client:
with session:
IF = InputFlowSlip39AdvancedRecoveryDryRun(
client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.recover(
session,
passphrase_protection=False,
@ -57,13 +57,13 @@ def test_2of3_dryrun(session: Session):
@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_ADVANCED_20)
def test_2of3_invalid_seed_dryrun(session: Session):
# test fails because of different seed on device
with session.client as client, pytest.raises(
with session, pytest.raises(
TrezorFailure, match=r"The seed does not match the one in the device"
):
IF = InputFlowSlip39AdvancedRecoveryDryRun(
client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True
session.client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.recover(
session,
passphrase_protection=False,

View File

@ -73,9 +73,9 @@ VECTORS = (
def test_secret(
session: Session, shares: list[str], secret: str, backup_type: messages.BackupType
):
with session.client as client:
IF = InputFlowSlip39BasicRecovery(client, shares)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicRecovery(session.client, shares)
session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label")
# Workflow successfully ended
@ -89,11 +89,11 @@ def test_secret(
@pytest.mark.setup_client(uninitialized=True)
def test_recover_with_pin_passphrase(session: Session):
with session.client as client:
with session:
IF = InputFlowSlip39BasicRecovery(
client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654"
session.client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654"
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.recover(
session,
pin_protection=True,
@ -109,9 +109,9 @@ def test_recover_with_pin_passphrase(session: Session):
@pytest.mark.setup_client(uninitialized=True)
def test_abort(session: Session):
with session.client as client:
IF = InputFlowSlip39BasicRecoveryAbort(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicRecoveryAbort(session.client)
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")
session.refresh_features()
@ -123,9 +123,9 @@ def test_abort(session: Session):
@pytest.mark.setup_client(uninitialized=True)
def test_abort_on_number_of_words(session: Session):
# on Caesar, test_abort actually aborts on the # of words selection
with session.client as client:
IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(session.client)
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")
assert session.features.initialized is False
@ -134,11 +134,11 @@ def test_abort_on_number_of_words(session: Session):
@pytest.mark.setup_client(uninitialized=True)
def test_abort_between_shares(session: Session):
with session.client as client:
with session:
IF = InputFlowSlip39BasicRecoveryAbortBetweenShares(
client, MNEMONIC_SLIP39_BASIC_20_3of6
session.client, MNEMONIC_SLIP39_BASIC_20_3of6
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")
session.refresh_features()
@ -148,9 +148,11 @@ def test_abort_between_shares(session: Session):
@pytest.mark.setup_client(uninitialized=True)
def test_noabort(session: Session):
with session.client as client:
IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicRecoveryNoAbort(
session.client, MNEMONIC_SLIP39_BASIC_20_3of6
)
session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label")
session.refresh_features()
assert session.features.initialized is True
@ -158,9 +160,9 @@ def test_noabort(session: Session):
@pytest.mark.setup_client(uninitialized=True)
def test_invalid_mnemonic_first_share(session: Session):
with session, session.client as client:
with session:
IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")
session.refresh_features()
@ -169,11 +171,11 @@ def test_invalid_mnemonic_first_share(session: Session):
@pytest.mark.setup_client(uninitialized=True)
def test_invalid_mnemonic_second_share(session: Session):
with session, session.client as client:
with session:
IF = InputFlowSlip39BasicRecoveryInvalidSecondShare(
session, MNEMONIC_SLIP39_BASIC_20_3of6
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")
session.refresh_features()
@ -184,9 +186,9 @@ def test_invalid_mnemonic_second_share(session: Session):
@pytest.mark.parametrize("nth_word", range(3))
def test_wrong_nth_word(session: Session, nth_word: int):
share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
with session, session.client as client:
with session:
IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")
@ -194,18 +196,18 @@ def test_wrong_nth_word(session: Session, nth_word: int):
@pytest.mark.setup_client(uninitialized=True)
def test_same_share(session: Session):
share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
with session, session.client as client:
with session:
IF = InputFlowSlip39BasicRecoverySameShare(session, share)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label")
@pytest.mark.setup_client(uninitialized=True)
def test_1of1(session: Session):
with session.client as client:
IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicRecovery(session.client, MNEMONIC_SLIP39_BASIC_20_1of1)
session.set_input_flow(IF.get())
device.recover(
session,
pin_protection=False,

View File

@ -38,9 +38,9 @@ INVALID_SHARES_20_2of3 = [
@pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2])
def test_2of3_dryrun(session: Session):
with session.client as client:
IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3])
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicRecoveryDryRun(session.client, SHARES_20_2of3[1:3])
session.set_input_flow(IF.get())
device.recover(
session,
passphrase_protection=False,
@ -53,13 +53,13 @@ def test_2of3_dryrun(session: Session):
@pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2])
def test_2of3_invalid_seed_dryrun(session: Session):
# test fails because of different seed on device
with session.client as client, pytest.raises(
with session, pytest.raises(
TrezorFailure, match=r"The seed does not match the one in the device"
):
IF = InputFlowSlip39BasicRecoveryDryRun(
client, INVALID_SHARES_20_2of3, mismatch=True
session.client, INVALID_SHARES_20_2of3, mismatch=True
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.recover(
session,
passphrase_protection=False,

View File

@ -32,9 +32,9 @@ from ...input_flows import (
def backup_flow_bip39(session: Session) -> bytes:
with session.client as client:
IF = InputFlowBip39Backup(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowBip39Backup(session.client)
session.set_input_flow(IF.get())
device.backup(session)
assert IF.mnemonic is not None
@ -42,9 +42,9 @@ def backup_flow_bip39(session: Session) -> bytes:
def backup_flow_slip39_basic(session: Session):
with session.client as client:
IF = InputFlowSlip39BasicBackup(client, False)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicBackup(session.client, False)
session.set_input_flow(IF.get())
device.backup(session)
groups = shamir.decode_mnemonics(IF.mnemonics[:3])
@ -53,9 +53,9 @@ def backup_flow_slip39_basic(session: Session):
def backup_flow_slip39_advanced(session: Session):
with session.client as client:
IF = InputFlowSlip39AdvancedBackup(client, False)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39AdvancedBackup(session.client, False)
session.set_input_flow(IF.get())
device.backup(session)
mnemonics = IF.mnemonics[0:3] + IF.mnemonics[5:8] + IF.mnemonics[10:13]
@ -116,9 +116,9 @@ def test_skip_backup_msg(session: Session, backup_type, backup_flow):
def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow):
assert session.features.initialized is False
with session, session.client as client:
IF = InputFlowResetSkipBackup(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowResetSkipBackup(session.client)
session.set_input_flow(IF.get())
device.setup(
session,
pin_protection=False,

View File

@ -36,9 +36,9 @@ pytestmark = pytest.mark.models("core")
def reset_device(session: Session, strength: int):
debug = session.client.debug
with session.client as client:
IF = InputFlowBip39ResetBackup(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowBip39ResetBackup(session.client)
session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random
device.setup(
@ -92,9 +92,9 @@ def test_reset_device_pin(session: Session):
debug = session.client.debug
strength = 256 # 24 words
with session.client as client:
IF = InputFlowBip39ResetPIN(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowBip39ResetPIN(session.client)
session.set_input_flow(IF.get())
# PIN, passphrase, display random
device.setup(
@ -130,9 +130,9 @@ def test_reset_device_pin(session: Session):
def test_reset_entropy_check(session: Session):
strength = 128 # 12 words
with session.client as client:
IF = InputFlowBip39ResetBackup(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowBip39ResetBackup(session.client)
session.set_input_flow(IF.get())
# No PIN, no passphrase
path_xpubs = device.setup(
@ -147,7 +147,7 @@ def test_reset_entropy_check(session: Session):
)
# Generate the mnemonic locally.
internal_entropy = client.debug.state().reset_entropy
internal_entropy = session.client.debug.state().reset_entropy
assert internal_entropy is not None
entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
expected_mnemonic = Mnemonic("english").to_mnemonic(entropy)
@ -156,7 +156,7 @@ def test_reset_entropy_check(session: Session):
assert IF.mnemonic == expected_mnemonic
# Check that the device is properly initialized.
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
if session.client.protocol_version is ProtocolVersion.PROTOCOL_V1:
features = session.call_raw(messages.Initialize())
else:
session.refresh_features()
@ -181,9 +181,9 @@ def test_reset_failed_check(session: Session):
debug = session.client.debug
strength = 256 # 24 words
with session.client as client:
IF = InputFlowBip39ResetFailedCheck(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowBip39ResetFailedCheck(session.client)
session.set_input_flow(IF.get())
# PIN, passphrase, display random
device.setup(

View File

@ -47,9 +47,9 @@ def test_reset_recovery(client: Client):
def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str:
with session.client as client:
IF = InputFlowBip39ResetBackup(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowBip39ResetBackup(session.client)
session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random
device.setup(
@ -77,10 +77,10 @@ def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> s
def recover(session: Session, mnemonic: str):
words = mnemonic.split(" ")
with session.client as client:
IF = InputFlowBip39Recovery(client, words)
client.set_input_flow(IF.get())
client.watch_layout()
with session:
IF = InputFlowBip39Recovery(session.client, words)
session.set_input_flow(IF.get())
session.client.watch_layout()
device.recover(session, pin_protection=False, label="label")
# Workflow successfully ended

View File

@ -68,9 +68,9 @@ def test_reset_recovery(client: Client):
def reset(session: Session, strength: int = 128) -> list[str]:
with session.client as client:
IF = InputFlowSlip39AdvancedResetRecovery(client, False)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39AdvancedResetRecovery(session.client, False)
session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random
device.setup(
@ -97,9 +97,9 @@ def reset(session: Session, strength: int = 128) -> list[str]:
def recover(session: Session, shares: list[str]):
with session.client as client:
IF = InputFlowSlip39AdvancedRecovery(client, shares, False)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39AdvancedRecovery(session.client, shares, False)
session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label")
# Workflow successfully ended

View File

@ -58,9 +58,9 @@ def test_reset_recovery(client: Client):
def reset(session: Session, strength: int = 128) -> list[str]:
with session.client as client:
IF = InputFlowSlip39BasicResetRecovery(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicResetRecovery(session.client)
session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random
device.setup(
@ -87,9 +87,9 @@ def reset(session: Session, strength: int = 128) -> list[str]:
def recover(session: Session, shares: t.Sequence[str]):
with session.client as client:
IF = InputFlowSlip39BasicRecovery(client, shares)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicRecovery(session.client, shares)
session.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label")
# Workflow successfully ended

View File

@ -34,10 +34,10 @@ def test_reset_device_slip39_advanced(client: Client):
strength = 128
member_threshold = 3
with client:
session = client.get_seedless_session()
with session:
IF = InputFlowSlip39AdvancedResetRecovery(client, False)
client.set_input_flow(IF.get())
session = client.get_seedless_session()
session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random
device.setup(
session,

View File

@ -34,9 +34,9 @@ pytestmark = pytest.mark.models("core")
def reset_device(session: Session, strength: int):
member_threshold = 3
with session.client as client:
IF = InputFlowSlip39BasicResetRecovery(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicResetRecovery(session.client)
session.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random
device.setup(
@ -89,9 +89,9 @@ def test_reset_entropy_check(session: Session):
strength = 128 # 20 words
with session.client as client:
IF = InputFlowSlip39BasicResetRecovery(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicResetRecovery(session.client)
session.set_input_flow(IF.get())
# No PIN, no passphrase.
path_xpubs = device.setup(

View File

@ -52,9 +52,9 @@ def test_ripple_get_address(session: Session, path: str, expected_address: str):
def test_ripple_get_address_chunkify_details(
session: Session, path: str, expected_address: str
):
with session.client as client:
IF = InputFlowShowAddressQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowAddressQRCode(session.client)
session.set_input_flow(IF.get())
address = get_address(
session, parse_path(path), show_display=True, chunkify=True
)

View File

@ -47,9 +47,9 @@ pytestmark = [
def test_solana_sign_tx(session: Session, parameters, result):
serialized_tx = _serialize_tx(parameters["construct"])
with session.client as client:
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
actual_result = sign_tx(
session,
address_n=parse_path(parameters["address"]),

View File

@ -122,9 +122,9 @@ def test_get_address(session: Session, parameters, result):
@pytest.mark.models("core")
@parametrize_using_common_fixtures("stellar/get_address.json")
def test_get_address_chunkify_details(session: Session, parameters, result):
with session.client as client:
IF = InputFlowShowAddressQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowAddressQRCode(session.client)
session.set_input_flow(IF.get())
address_n = parse_path(parameters["path"])
address = stellar.get_address(
session, address_n, show_display=True, chunkify=True

View File

@ -38,8 +38,8 @@ def pin_request(session: Session):
def set_autolock_delay(session: Session, delay):
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses(
[
pin_request(session),
@ -61,8 +61,8 @@ def test_apply_auto_lock_delay(session: Session):
get_test_address(session)
time.sleep(10.5) # sleep more than auto-lock delay
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses([pin_request(session), messages.Address])
get_test_address(session)
@ -85,8 +85,8 @@ def test_apply_auto_lock_delay_valid(session: Session, seconds):
def test_autolock_default_value(session: Session):
assert session.features.auto_lock_delay_ms is None
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
device.apply_settings(session, label="pls unlock")
session.refresh_features()
assert session.features.auto_lock_delay_ms == 60 * 10 * 1000
@ -98,8 +98,8 @@ def test_autolock_default_value(session: Session):
)
def test_apply_auto_lock_delay_out_of_range(session: Session, seconds):
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses(
[
pin_request(session),

View File

@ -48,8 +48,8 @@ def test_busy_state(session: Session):
_assert_busy(session, True)
assert session.features.unlocked is False
with session.client as client:
client.use_pin_sequence([PIN])
with session:
session.client.use_pin_sequence([PIN])
btc.get_address(
session, "Bitcoin", parse_path("m/44h/0h/0h/0/0"), show_display=True
)

View File

@ -40,9 +40,9 @@ def test_cancel_message_via_cancel(session: Session, message):
yield
session.cancel()
with session, session.client as client, pytest.raises(Cancelled):
with session, pytest.raises(Cancelled):
session.set_expected_responses([m.ButtonRequest(), m.Failure()])
client.set_input_flow(input_flow)
session.set_input_flow(input_flow)
session.call(message)

View File

@ -47,12 +47,12 @@ def test_pin(session: Session):
)
assert isinstance(resp, messages.PinMatrixRequest)
with session.client as client:
state = client.debug.state()
with session:
state = session.client.debug.state()
assert state.pin == "1234"
assert state.matrix != ""
pin_encoded = client.debug.encode_pin("1234")
pin_encoded = session.client.debug.encode_pin("1234")
resp = session.call_raw(messages.PinMatrixAck(pin=pin_encoded))
assert isinstance(resp, messages.PassphraseRequest)

View File

@ -79,9 +79,9 @@ def _check_ping_screen_texts(session: Session, title: str, right_button: str) ->
if session.model in (models.T2T1, models.T3T1):
right_button = "-"
with session, session.client as client:
client.watch_layout(True)
client.set_input_flow(ping_input_flow(session, title, right_button))
with session:
session.client.watch_layout(True)
session.set_input_flow(ping_input_flow(session, title, right_button))
ping = session.call(messages.Ping(message="ahoj!", button_protection=True))
assert ping == messages.Success(message="ahoj!")
@ -274,8 +274,8 @@ def test_reject_update(session: Session):
yield
session.client.debug.press_no()
with pytest.raises(exceptions.Cancelled), session, session.client as client:
client.set_input_flow(input_flow_reject)
with pytest.raises(exceptions.Cancelled), session:
session.set_input_flow(input_flow_reject)
device.change_language(session, language_data)
assert session.features.language == "en-US"

View File

@ -345,12 +345,12 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways
with session, session.client as client:
with session:
session.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address]
)
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
get_bad_address()
with session:
@ -371,13 +371,13 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily
with session, session.client as client:
with session:
session.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address]
)
if session.model is not models.T1B1:
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
IF = InputFlowConfirmAllWarnings(session.client)
session.set_input_flow(IF.get())
get_bad_address()
@ -412,8 +412,8 @@ def test_experimental_features(session: Session):
# relock and try again
session.lock()
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses([messages.ButtonRequest, messages.Nonce])
experimental_call()

View File

@ -44,9 +44,9 @@ from ..input_flows import (
def test_backup_bip39(session: Session):
assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client:
IF = InputFlowBip39Backup(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowBip39Backup(session.client)
session.set_input_flow(IF.get())
device.backup(session)
assert IF.mnemonic == MNEMONIC12
@ -71,9 +71,9 @@ def test_backup_slip39_basic(session: Session, click_info: bool):
assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client:
IF = InputFlowSlip39BasicBackup(client, click_info)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicBackup(session.client, click_info)
session.set_input_flow(IF.get())
device.backup(session)
session.refresh_features()
@ -95,11 +95,12 @@ def test_backup_slip39_basic(session: Session, click_info: bool):
def test_backup_slip39_single(session: Session):
assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client:
with session:
IF = InputFlowBip39Backup(
client, confirm_success=(client.layout_type is not LayoutType.Delizia)
session.client,
confirm_success=(session.client.layout_type is not LayoutType.Delizia),
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.backup(session)
assert session.features.initialized is True
@ -126,9 +127,9 @@ def test_backup_slip39_advanced(session: Session, click_info: bool):
assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client:
IF = InputFlowSlip39AdvancedBackup(client, click_info)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39AdvancedBackup(session.client, click_info)
session.set_input_flow(IF.get())
device.backup(session)
session.refresh_features()
@ -157,9 +158,9 @@ def test_backup_slip39_advanced(session: Session, click_info: bool):
def test_backup_slip39_custom(session: Session, share_threshold, share_count):
assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client:
IF = InputFlowSlip39CustomBackup(client, share_count)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39CustomBackup(session.client, share_count)
session.set_input_flow(IF.get())
device.backup(
session, group_threshold=1, groups=[(share_threshold, share_count)]
)

View File

@ -34,7 +34,7 @@ pytestmark = pytest.mark.models("legacy")
def _set_wipe_code(session: Session, pin, wipe_code):
# Set/change wipe code.
with session.client as client, session:
with session:
if session.features.pin_protection:
pins = [pin, wipe_code, wipe_code]
pin_matrices = [
@ -49,7 +49,7 @@ def _set_wipe_code(session: Session, pin, wipe_code):
messages.PinMatrixRequest(type=PinType.WipeCodeSecond),
]
client.use_pin_sequence(pins)
session.client.use_pin_sequence(pins)
session.set_expected_responses(
[messages.ButtonRequest()] + pin_matrices + [messages.Success]
)
@ -58,8 +58,8 @@ def _set_wipe_code(session: Session, pin, wipe_code):
def _change_pin(session: Session, old_pin, new_pin):
assert session.features.pin_protection is True
with session.client as client:
client.use_pin_sequence([old_pin, new_pin, new_pin])
with session:
session.client.use_pin_sequence([old_pin, new_pin, new_pin])
try:
return device.change_pin(session)
except exceptions.TrezorFailure as f:
@ -96,8 +96,8 @@ def test_set_remove_wipe_code(session: Session):
_check_wipe_code(session, PIN4, WIPE_CODE6)
# Test remove wipe code.
with session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
device.change_wipe_code(session, remove=True)
# Check that there's no wipe code protection now.
@ -111,8 +111,8 @@ def test_set_wipe_code_mismatch(session: Session):
assert session.features.wipe_code_protection is False
# Let's set a new wipe code.
with session.client as client, session:
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6])
with session:
session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6])
session.set_expected_responses(
[
messages.ButtonRequest(),
@ -125,8 +125,8 @@ def test_set_wipe_code_mismatch(session: Session):
device.change_wipe_code(session)
# Check that there is no wipe code protection.
client.refresh_features()
assert client.features.wipe_code_protection is False
session.client.refresh_features()
assert session.client.features.wipe_code_protection is False
@pytest.mark.setup_client(pin=PIN4)
@ -135,8 +135,8 @@ def test_set_wipe_code_to_pin(session: Session):
assert session.features.wipe_code_protection is None
# Let's try setting the wipe code to the curent PIN value.
with session.client as client, session:
client.use_pin_sequence([PIN4, PIN4])
with session:
session.client.use_pin_sequence([PIN4, PIN4])
session.set_expected_responses(
[
messages.ButtonRequest(),
@ -149,8 +149,8 @@ def test_set_wipe_code_to_pin(session: Session):
device.change_wipe_code(session)
# Check that there is no wipe code protection.
client.refresh_features()
assert client.features.wipe_code_protection is False
session.client.refresh_features()
assert session.client.features.wipe_code_protection is False
def test_set_pin_to_wipe_code(session: Session):
@ -159,8 +159,8 @@ def test_set_pin_to_wipe_code(session: Session):
_set_wipe_code(session, None, WIPE_CODE4)
# Try to set the PIN to the current wipe code value.
with session.client as client, session:
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
with session:
session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
session.set_expected_responses(
[
messages.ButtonRequest(),

View File

@ -37,8 +37,8 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str):
assert session.features.wipe_code_protection is True
# Try to change the PIN to the current wipe code value. The operation should fail.
with session, session.client as client, pytest.raises(TrezorFailure):
client.use_pin_sequence([pin, wipe_code, wipe_code])
with session, pytest.raises(TrezorFailure):
session.client.use_pin_sequence([pin, wipe_code, wipe_code])
if session.client.layout_type is LayoutType.Caesar:
br_count = 6
else:
@ -51,8 +51,8 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str):
def _ensure_unlocked(session: Session, pin: str):
with session, session.client as client:
client.use_pin_sequence([pin])
with session:
session.client.use_pin_sequence([pin])
btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
session.refresh_features()
@ -71,11 +71,11 @@ def test_set_remove_wipe_code(session: Session):
else:
br_count = 5
with session, session.client as client:
with session:
session.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success]
)
client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX])
session.client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX])
device.change_wipe_code(session)
# session.init_device()
@ -83,11 +83,11 @@ def test_set_remove_wipe_code(session: Session):
_check_wipe_code(session, PIN4, WIPE_CODE_MAX)
# Test change wipe code.
with session, session.client as client:
with session:
session.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success]
)
client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6])
session.client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6])
device.change_wipe_code(session)
# session.init_device()
@ -95,11 +95,11 @@ def test_set_remove_wipe_code(session: Session):
_check_wipe_code(session, PIN4, WIPE_CODE6)
# Test remove wipe code.
with session, session.client as client:
with session:
session.set_expected_responses(
[messages.ButtonRequest()] * 3 + [messages.Success]
)
client.use_pin_sequence([PIN4])
session.client.use_pin_sequence([PIN4])
device.change_wipe_code(session, remove=True)
# session.init_device()
@ -107,9 +107,11 @@ def test_set_remove_wipe_code(session: Session):
def test_set_wipe_code_mismatch(session: Session):
with session, session.client as client, pytest.raises(TrezorFailure):
IF = InputFlowNewCodeMismatch(client, WIPE_CODE4, WIPE_CODE6, what="wipe_code")
client.set_input_flow(IF.get())
with session, pytest.raises(TrezorFailure):
IF = InputFlowNewCodeMismatch(
session.client, WIPE_CODE4, WIPE_CODE6, what="wipe_code"
)
session.set_input_flow(IF.get())
device.change_wipe_code(session)
@ -122,15 +124,15 @@ def test_set_wipe_code_mismatch(session: Session):
def test_set_wipe_code_to_pin(session: Session):
_ensure_unlocked(session, PIN4)
with session, session.client as client:
if client.layout_type is LayoutType.Caesar:
with session:
if session.client.layout_type is LayoutType.Caesar:
br_count = 8
else:
br_count = 7
session.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success],
)
client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4])
session.client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4])
device.change_wipe_code(session)
# session.init_device()
@ -140,20 +142,20 @@ def test_set_wipe_code_to_pin(session: Session):
def test_set_pin_to_wipe_code(session: Session):
# Set wipe code.
with session, session.client as client:
if client.layout_type is LayoutType.Caesar:
with session:
if session.client.layout_type is LayoutType.Caesar:
br_count = 5
else:
br_count = 4
session.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success]
)
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
device.change_wipe_code(session)
# Try to set the PIN to the current wipe code value.
with session, session.client as client, pytest.raises(TrezorFailure):
if client.layout_type is LayoutType.Caesar:
with session, pytest.raises(TrezorFailure):
if session.client.layout_type is LayoutType.Caesar:
br_count = 6
else:
br_count = 4
@ -161,5 +163,5 @@ def test_set_pin_to_wipe_code(session: Session):
[messages.ButtonRequest()] * br_count
+ [messages.Failure(code=messages.FailureType.PinInvalid)]
)
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
session.client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
device.change_pin(session)

View File

@ -33,8 +33,8 @@ pytestmark = pytest.mark.models("legacy")
def _check_pin(session: Session, pin):
session.lock()
with session, session.client as client:
client.use_pin_sequence([pin])
with session:
session.client.use_pin_sequence([pin])
session.set_expected_responses([messages.PinMatrixRequest, messages.Address])
get_test_address(session)
@ -53,8 +53,8 @@ def test_set_pin(session: Session):
_check_no_pin(session)
# Let's set new PIN
with session, session.client as client:
client.use_pin_sequence([PIN_MAX, PIN_MAX])
with session:
session.client.use_pin_sequence([PIN_MAX, PIN_MAX])
session.set_expected_responses(
[
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
@ -78,8 +78,8 @@ def test_change_pin(session: Session):
_check_pin(session, PIN4)
# Let's change PIN
with session, session.client as client:
client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX])
with session:
session.client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX])
session.set_expected_responses(
[
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
@ -104,8 +104,8 @@ def test_remove_pin(session: Session):
_check_pin(session, PIN4)
# Let's remove PIN
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses(
[
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
@ -126,11 +126,9 @@ def test_set_mismatch(session: Session):
_check_no_pin(session)
# Let's set new PIN
with session, session.client as client, pytest.raises(
TrezorFailure, match="PIN mismatch"
):
with session, pytest.raises(TrezorFailure, match="PIN mismatch"):
# use different PINs for first and second attempt. This will fail.
client.use_pin_sequence([PIN4, PIN_MAX])
session.client.use_pin_sequence([PIN4, PIN_MAX])
session.set_expected_responses(
[
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
@ -152,10 +150,8 @@ def test_change_mismatch(session: Session):
assert session.features.pin_protection is True
# Let's set new PIN
with session, session.client as client, pytest.raises(
TrezorFailure, match="PIN mismatch"
):
client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"])
with session, pytest.raises(TrezorFailure, match="PIN mismatch"):
session.client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"])
session.set_expected_responses(
[
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),

View File

@ -37,9 +37,9 @@ pytestmark = pytest.mark.models("core")
def _check_pin(session: Session, pin: str):
with session, session.client as client:
client.ui.__init__(client.debug)
client.use_pin_sequence([pin, pin, pin, pin, pin, pin])
with session:
session.client.ui.__init__(session.client.debug)
session.client.use_pin_sequence([pin, pin, pin, pin, pin, pin])
session.lock()
assert session.features.pin_protection is True
assert session.features.unlocked is False
@ -63,12 +63,12 @@ def test_set_pin(session: Session):
_check_no_pin(session)
# Let's set new PIN
with session, session.client as client:
if client.layout_type is LayoutType.Caesar:
with session:
if session.client.layout_type is LayoutType.Caesar:
br_count = 6
else:
br_count = 4
client.use_pin_sequence([PIN_MAX, PIN_MAX])
session.client.use_pin_sequence([PIN_MAX, PIN_MAX])
session.set_expected_responses(
[messages.ButtonRequest] * br_count + [messages.Success]
)
@ -86,9 +86,9 @@ def test_change_pin(session: Session):
_check_pin(session, PIN4)
# Let's change PIN
with session, session.client as client:
client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX])
if client.layout_type is LayoutType.Caesar:
with session:
session.client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX])
if session.client.layout_type is LayoutType.Caesar:
br_count = 6
else:
br_count = 5
@ -113,8 +113,8 @@ def test_remove_pin(session: Session):
_check_pin(session, PIN4)
# Let's remove PIN
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses(
[messages.ButtonRequest] * 3 + [messages.Success]
)
@ -132,9 +132,9 @@ def test_set_failed(session: Session):
# Check that there's no PIN protection
_check_no_pin(session)
with session, session.client as client, pytest.raises(TrezorFailure):
IF = InputFlowNewCodeMismatch(client, PIN4, PIN60, what="pin")
client.set_input_flow(IF.get())
with session, pytest.raises(TrezorFailure):
IF = InputFlowNewCodeMismatch(session.client, PIN4, PIN60, what="pin")
session.set_input_flow(IF.get())
device.change_pin(session)
@ -151,9 +151,9 @@ def test_change_failed(session: Session):
# Check current PIN value
_check_pin(session, PIN4)
with session, session.client as client, pytest.raises(Cancelled):
with session, pytest.raises(Cancelled):
IF = InputFlowCodeChangeFail(session, PIN4, "457891", "381847")
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.change_pin(session)
@ -170,9 +170,9 @@ def test_change_invalid_current(session: Session):
# Check current PIN value
_check_pin(session, PIN4)
with session, session.client as client, pytest.raises(TrezorFailure):
IF = InputFlowWrongPIN(client, PIN60)
client.set_input_flow(IF.get())
with session, pytest.raises(TrezorFailure):
IF = InputFlowWrongPIN(session.client, PIN60)
session.set_input_flow(IF.get())
device.change_pin(session)
@ -200,7 +200,7 @@ def test_pin_menu_cancel_setup(session: Session):
# tap to confirm
debug.click(debug.screen_buttons.tap_to_confirm())
with session, session.client as client, pytest.raises(Cancelled):
client.set_input_flow(cancel_pin_setup_input_flow)
with session, pytest.raises(Cancelled):
session.set_input_flow(cancel_pin_setup_input_flow)
session.call(messages.ChangePin())
_check_no_pin(session)

View File

@ -45,9 +45,8 @@ def test_wipe_device(client: Client):
@pytest.mark.setup_client(pin=PIN4)
def test_autolock_not_retained(session: Session):
client = session.client
with client:
client.use_pin_sequence([PIN4])
device.apply_settings(session, auto_lock_delay_ms=10_000)
client.use_pin_sequence([PIN4])
device.apply_settings(session, auto_lock_delay_ms=10_000)
assert session.features.auto_lock_delay_ms == 10_000
@ -57,21 +56,20 @@ def test_autolock_not_retained(session: Session):
assert client.features.auto_lock_delay_ms > 10_000
with client:
client.use_pin_sequence([PIN4, PIN4])
device.setup(
session,
skip_backup=True,
pin_protection=True,
passphrase_protection=False,
entropy_check_count=0,
backup_type=messages.BackupType.Bip39,
)
client.use_pin_sequence([PIN4, PIN4])
device.setup(
session,
skip_backup=True,
pin_protection=True,
passphrase_protection=False,
entropy_check_count=0,
backup_type=messages.BackupType.Bip39,
)
time.sleep(10.5)
session = client.get_session()
with session, client:
with session:
# after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked
session.set_expected_responses([messages.Address])
get_test_address(session)

View File

@ -39,8 +39,8 @@ def test_no_protection(session: Session):
def test_correct_pin(session: Session):
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
# Expected responses differ between T1 and TT
is_t1 = session.model is models.T1B1
session.set_expected_responses(
@ -65,9 +65,9 @@ def test_incorrect_pin_t1(session: Session):
@pytest.mark.models("core")
def test_incorrect_pin_t2(session: Session):
with session, session.client as client:
with session:
# After first incorrect attempt, TT will not raise an error, but instead ask for another attempt
client.use_pin_sequence([BAD_PIN, PIN4])
session.client.use_pin_sequence([BAD_PIN, PIN4])
session.set_expected_responses(
[
messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry),
@ -82,15 +82,15 @@ def test_incorrect_pin_t2(session: Session):
def test_exponential_backoff_t1(session: Session):
for attempt in range(3):
start = time.time()
with session, session.client as client, pytest.raises(PinException):
client.use_pin_sequence([BAD_PIN])
with session, pytest.raises(PinException):
session.client.use_pin_sequence([BAD_PIN])
get_test_address(session)
check_pin_backoff_time(attempt, start)
@pytest.mark.models("core")
def test_exponential_backoff_t2(session: Session):
with session.client as client:
IF = InputFlowPINBackoff(client, BAD_PIN, PIN4)
client.set_input_flow(IF.get())
with session:
IF = InputFlowPINBackoff(session.client, BAD_PIN, PIN4)
session.set_input_flow(IF.get())
get_test_address(session)

View File

@ -56,12 +56,12 @@ def _assert_protection(
session: Session, pin: bool = True, passphrase: bool = True
) -> Session:
"""Make sure PIN and passphrase protection have expected values"""
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.ensure_unlocked()
client.refresh_features()
assert client.features.pin_protection is pin
assert client.features.passphrase_protection is passphrase
session.client.refresh_features()
assert session.client.features.pin_protection is pin
assert session.client.features.passphrase_protection is passphrase
session.lock()
# session.end()
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
@ -70,8 +70,8 @@ def _assert_protection(
def test_initialize(session: Session):
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.ensure_unlocked()
session = _assert_protection(session)
with session:
@ -86,8 +86,8 @@ def test_passphrase_reporting(session: Session, passphrase):
"""On TT, passphrase_protection is a private setting, so a locked device should
report passphrase_protection=None.
"""
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
device.apply_settings(session, use_passphrase=passphrase)
session.lock()
@ -108,8 +108,8 @@ def test_passphrase_reporting(session: Session, passphrase):
def test_apply_settings(session: Session):
session = _assert_protection(session)
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses(
[
_pin_request(session),
@ -124,8 +124,8 @@ def test_apply_settings(session: Session):
@pytest.mark.models("legacy")
def test_change_pin_t1(session: Session):
session = _assert_protection(session)
with session, session.client as client:
client.use_pin_sequence([PIN4, PIN4, PIN4])
with session:
session.client.use_pin_sequence([PIN4, PIN4, PIN4])
session.set_expected_responses(
[
messages.ButtonRequest,
@ -141,8 +141,8 @@ def test_change_pin_t1(session: Session):
@pytest.mark.models("core")
def test_change_pin_t2(session: Session):
session = _assert_protection(session)
with session, session.client as client:
client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4])
with session:
session.client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4])
session.set_expected_responses(
[
_pin_request(session),
@ -172,8 +172,8 @@ def test_ping(session: Session):
def test_get_entropy(session: Session):
session = _assert_protection(session)
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses(
[
_pin_request(session),
@ -187,8 +187,8 @@ def test_get_entropy(session: Session):
def test_get_public_key(session: Session):
session = _assert_protection(session)
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
@ -202,8 +202,8 @@ def test_get_public_key(session: Session):
def test_get_address(session: Session):
session = _assert_protection(session)
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)
@ -221,8 +221,8 @@ def test_wipe_device(session: Session):
device.wipe(session)
client = session.client.get_new_client()
session = client.get_seedless_session()
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses([messages.Features])
session.call(messages.GetFeatures())
@ -301,8 +301,8 @@ def test_recovery_device(session: Session):
def test_sign_message(session: Session):
session = _assert_protection(session)
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]
@ -350,8 +350,8 @@ def test_verify_message_t1(session: Session):
@pytest.mark.models("core")
def test_verify_message_t2(session: Session):
session = _assert_protection(session)
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses(
[
_pin_request(session),
@ -389,8 +389,8 @@ def test_signtx(session: Session):
)
session = _assert_protection(session)
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)
@ -430,8 +430,8 @@ def test_unlocked(session: Session):
session = _assert_protection(session, passphrase=False)
with session, session.client as client:
client.use_pin_sequence([PIN4])
with session:
session.client.use_pin_sequence([PIN4])
session.set_expected_responses([_pin_request(session), messages.Address])
get_test_address(session)

View File

@ -39,9 +39,9 @@ def test_repeated_backup(session: Session):
# initial device backup
mnemonics = []
with session, session.client as client:
IF = InputFlowSlip39BasicBackup(client, False)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicBackup(session.client, False)
session.set_input_flow(IF.get())
device.backup(session)
mnemonics = IF.mnemonics
@ -56,11 +56,11 @@ def test_repeated_backup(session: Session):
device.backup(session)
# unlock repeated backup by entering 3 of the 5 shares we have got
with session, session.client as client:
with session:
IF = InputFlowSlip39BasicRecoveryDryRun(
client, mnemonics[:3], unlock_repeated_backup=True
session.client, mnemonics[:3], unlock_repeated_backup=True
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup)
assert (
session.features.backup_availability
@ -69,9 +69,9 @@ def test_repeated_backup(session: Session):
assert session.features.recovery_status == messages.RecoveryStatus.Backup
# we can now perform another backup
with session, session.client as client:
IF = InputFlowSlip39BasicBackup(client, False, repeated=True)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicBackup(session.client, False, repeated=True)
session.set_input_flow(IF.get())
device.backup(session)
# the backup feature is locked again...
@ -92,11 +92,11 @@ def test_repeated_backup_upgrade_single(session: Session):
assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable
# unlock repeated backup by entering the single share
with session, session.client as client:
with session:
IF = InputFlowSlip39BasicRecoveryDryRun(
client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True
session.client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup)
assert (
session.features.backup_availability
@ -105,9 +105,9 @@ def test_repeated_backup_upgrade_single(session: Session):
assert session.features.recovery_status == messages.RecoveryStatus.Backup
# we can now perform another backup
with session, session.client as client:
IF = InputFlowSlip39BasicBackup(client, False, repeated=True)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicBackup(session.client, False, repeated=True)
session.set_input_flow(IF.get())
device.backup(session)
# backup type was upgraded:
@ -128,9 +128,9 @@ def test_repeated_backup_cancel(session: Session):
# initial device backup
mnemonics = []
with session, session.client as client:
IF = InputFlowSlip39BasicBackup(client, False)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicBackup(session.client, False)
session.set_input_flow(IF.get())
device.backup(session)
mnemonics = IF.mnemonics
@ -145,11 +145,11 @@ def test_repeated_backup_cancel(session: Session):
device.backup(session)
# unlock repeated backup by entering 3 of the 5 shares we have got
with session, session.client as client:
with session:
IF = InputFlowSlip39BasicRecoveryDryRun(
client, mnemonics[:3], unlock_repeated_backup=True
session.client, mnemonics[:3], unlock_repeated_backup=True
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup)
assert (
session.features.backup_availability
@ -183,9 +183,9 @@ def test_repeated_backup_send_disallowed_message(session: Session):
# initial device backup
mnemonics = []
with session, session.client as client:
IF = InputFlowSlip39BasicBackup(client, False)
client.set_input_flow(IF.get())
with session:
IF = InputFlowSlip39BasicBackup(session.client, False)
session.set_input_flow(IF.get())
device.backup(session)
mnemonics = IF.mnemonics
@ -200,11 +200,11 @@ def test_repeated_backup_send_disallowed_message(session: Session):
device.backup(session)
# unlock repeated backup by entering 3 of the 5 shares we have got
with session, session.client as client:
with session:
IF = InputFlowSlip39BasicRecoveryDryRun(
client, mnemonics[:3], unlock_repeated_backup=True
session.client, mnemonics[:3], unlock_repeated_backup=True
)
client.set_input_flow(IF.get())
session.set_input_flow(IF.get())
device.recover(session, type=messages.RecoveryType.UnlockRepeatedBackup)
assert (
session.features.backup_availability

View File

@ -45,8 +45,8 @@ def test_sd_no_format(session: Session):
yield # format SD card
debug.press_no()
with session, session.client as client, pytest.raises(TrezorFailure) as e:
client.set_input_flow(input_flow)
with session, pytest.raises(TrezorFailure) as e:
session.set_input_flow(input_flow)
device.sd_protect(session, Op.ENABLE)
assert e.value.code == messages.FailureType.ProcessError
@ -76,9 +76,9 @@ def test_sd_protect_unlock(session: Session):
assert TR.sd_card__enabled in layout().text_content()
debug.press_yes()
with session, session.client as client:
client.watch_layout()
client.set_input_flow(input_flow_enable_sd_protect)
with session:
session.client.watch_layout()
session.set_input_flow(input_flow_enable_sd_protect)
device.sd_protect(session, Op.ENABLE)
def input_flow_change_pin():
@ -102,9 +102,9 @@ def test_sd_protect_unlock(session: Session):
assert TR.pin__changed in layout().text_content()
debug.press_yes()
with session, session.client as client:
client.watch_layout()
client.set_input_flow(input_flow_change_pin)
with session:
session.client.watch_layout()
session.set_input_flow(input_flow_change_pin)
device.change_pin(session)
debug.erase_sd_card(format=False)
@ -125,9 +125,9 @@ def test_sd_protect_unlock(session: Session):
)
debug.press_no() # close
with session, session.client as client, pytest.raises(TrezorFailure) as e:
client.watch_layout()
client.set_input_flow(input_flow_change_pin_format)
with session, pytest.raises(TrezorFailure) as e:
session.client.watch_layout()
session.set_input_flow(input_flow_change_pin_format)
device.change_pin(session)
assert e.value.code == messages.FailureType.ProcessError

View File

@ -41,7 +41,7 @@ def test_clear_session(client: Client):
cached_responses = [messages.PublicKey]
session = client.get_session()
session.lock()
with client, session:
with session:
client.use_pin_sequence([PIN4])
session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB
@ -57,7 +57,7 @@ def test_clear_session(client: Client):
session = client.get_session()
# session cache is cleared
with client, session:
with session:
client.use_pin_sequence([PIN4])
session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB
@ -76,7 +76,7 @@ def test_end_session(client: Client):
assert session.id is not None
# get_address will succeed
with session:
with session as session:
session.set_expected_responses([messages.Address])
get_test_address(session)
@ -135,7 +135,7 @@ def test_end_session_only_current(client: Client):
@pytest.mark.setup_client(passphrase=True)
def test_session_recycling(client: Client):
session = client.get_session(passphrase="TREZOR")
with client, session:
with session:
session.set_expected_responses(
[
messages.PassphraseRequest,
@ -152,7 +152,7 @@ def test_session_recycling(client: Client):
session_x.end()
# it should still be possible to resume the original session
with client, session:
with session:
# passphrase should still be cached
session.set_expected_responses([messages.Address] * 3)
client.resume_session(session)

View File

@ -396,7 +396,7 @@ def test_passphrase_length(client: Client):
def test_hide_passphrase_from_host(client: Client):
# Without safety checks, turning it on fails
session = client.get_seedless_session()
with pytest.raises(TrezorFailure, match="Safety checks are strict"), client:
with pytest.raises(TrezorFailure, match="Safety checks are strict"):
device.apply_settings(session, hide_passphrase_from_host=True)
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
@ -406,7 +406,7 @@ def test_hide_passphrase_from_host(client: Client):
passphrase = "abc"
session = client.get_session(passphrase=passphrase)
with client, session:
with session:
def input_flow():
yield
@ -421,8 +421,8 @@ def test_hide_passphrase_from_host(client: Client):
else:
raise KeyError
client.watch_layout()
client.set_input_flow(input_flow)
session.client.watch_layout()
session.set_input_flow(input_flow)
session.set_expected_responses(
[
messages.PassphraseRequest,
@ -440,7 +440,7 @@ def test_hide_passphrase_from_host(client: Client):
# Starting new session, otherwise the passphrase would be cached
session = client.get_session(passphrase=passphrase)
with client, session:
with session:
def input_flow():
yield
@ -455,8 +455,8 @@ def test_hide_passphrase_from_host(client: Client):
assert passphrase in client.debug.read_layout().text_content()
client.debug.press_yes()
client.watch_layout()
client.set_input_flow(input_flow)
session.client.watch_layout()
session.set_input_flow(input_flow)
session.set_expected_responses(
[
messages.PassphraseRequest,

View File

@ -44,9 +44,9 @@ def test_tezos_get_address(session: Session, path: str, expected_address: str):
def test_tezos_get_address_chunkify_details(
session: Session, path: str, expected_address: str
):
with session.client as client:
IF = InputFlowShowAddressQRCode(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowShowAddressQRCode(session.client)
session.set_input_flow(IF.get())
address = get_address(
session, parse_path(path), show_display=True, chunkify=True
)

View File

@ -31,9 +31,9 @@ RK_CAPACITY = 100
@pytest.mark.altcoin
@pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_add_remove(session: Session):
with session, session.client as client:
IF = InputFlowFidoConfirm(client)
client.set_input_flow(IF.get())
with session:
IF = InputFlowFidoConfirm(session.client)
session.set_input_flow(IF.get())
# Remove index 0 should fail.
with pytest.raises(TrezorFailure):

View File

@ -50,16 +50,18 @@ class InputFlowBase:
# There could be one common input flow for all models
if hasattr(self, "input_flow_common"):
return getattr(self, "input_flow_common")
flow = getattr(self, "input_flow_common")
elif self.client.layout_type is LayoutType.Bolt:
return self.input_flow_bolt
flow = self.input_flow_bolt
elif self.client.layout_type is LayoutType.Caesar:
return self.input_flow_caesar
flow = self.input_flow_caesar
elif self.client.layout_type is LayoutType.Delizia:
return self.input_flow_delizia
flow = self.input_flow_delizia
else:
raise ValueError("Unknown model")
return flow
def input_flow_bolt(self) -> BRGeneratorType:
"""Special for TT"""
raise NotImplementedError
@ -371,7 +373,7 @@ class InputFlowSignMessageInfo(InputFlowBase):
self.debug.click(self.client.debug.screen_buttons.vertical_menu_items()[1])
# address mismatch? yes!
self.debug.swipe_up()
yield
yield # ?
class InputFlowShowAddressQRCode(InputFlowBase):

View File

@ -11,33 +11,37 @@ WIPE_CODE = "9876"
def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None:
device.wipe(client.get_seedless_session())
session = client.get_seedless_session()
device.wipe(session)
client = client.get_new_client()
session = client.get_seedless_session()
debuglink.load_device(
client.get_seedless_session(),
session,
MNEMONIC12,
pin,
passphrase_protection=False,
label="WIPECODE",
)
with client:
with session:
client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE])
device.change_wipe_code(client.get_seedless_session())
def setup_device_core(client: Client, pin: str, wipe_code: str) -> None:
device.wipe(client.get_seedless_session())
session = client.get_seedless_session()
device.wipe(session)
client = client.get_new_client()
session = client.get_seedless_session()
debuglink.load_device(
client.get_seedless_session(),
session,
MNEMONIC12,
pin,
passphrase_protection=False,
label="WIPECODE",
)
with client:
with session:
client.use_pin_sequence([pin, wipe_code, wipe_code])
device.change_wipe_code(client.get_seedless_session())

View File

@ -96,7 +96,7 @@ def test_upgrade_load_pin(gen: str, tag: str) -> None:
assert client.features.initialized
assert client.features.label == LABEL
session = client.get_session()
with client, session:
with session:
client.use_pin_sequence([PIN])
assert btc.get_address(session, "Bitcoin", PATH) == ADDRESS
@ -395,10 +395,11 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
# Create a backup of the encrypted master secret.
assert emu.client.features.backup_availability == BackupAvailability.Required
with emu.client:
session = emu.client.get_session()
with session:
IF = InputFlowSlip39BasicBackup(emu.client, False)
emu.client.set_input_flow(IF.get())
device.backup(emu.client.get_session())
session.set_input_flow(IF.get())
device.backup(session)
assert (
emu.client.features.backup_availability == BackupAvailability.NotAvailable
)