1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-14 12:58:46 +00:00

refactor(tests): move set_input_flow to SessionDebugWrapper context manager

[no changelog]
This commit is contained in:
Martin Milata 2025-03-06 00:35:51 +01:00 committed by M1nd3r
parent 3ab0db670c
commit eb2dd59d50
89 changed files with 979 additions and 1068 deletions

View File

@ -788,10 +788,10 @@ class DebugUI:
def __init__(self, debuglink: DebugLink) -> None: def __init__(self, debuglink: DebugLink) -> None:
self.debuglink = debuglink self.debuglink = debuglink
self.pins: t.Iterator[str] | None = None
self.clear() self.clear()
def clear(self) -> None: def clear(self) -> None:
self.pins: t.Iterator[str] | None = None
self.passphrase = None self.passphrase = None
self.input_flow: t.Union[ self.input_flow: t.Union[
t.Generator[None, messages.ButtonRequest, None], object, None t.Generator[None, messages.ButtonRequest, None], object, None
@ -947,7 +947,6 @@ class SessionDebugWrapper(Session):
if isinstance(session, SessionDebugWrapper): if isinstance(session, SessionDebugWrapper):
raise Exception("Cannot wrap already wrapped session!") raise Exception("Cannot wrap already wrapped session!")
self.__dict__["_session"] = session self.__dict__["_session"] = session
self.reset_debug_features()
def __getattr__(self, name: str) -> t.Any: def __getattr__(self, name: str) -> t.Any:
return getattr(self._session, name) return getattr(self._session, name)
@ -962,61 +961,24 @@ class SessionDebugWrapper(Session):
def protocol_version(self) -> int: def protocol_version(self) -> int:
return self.client.protocol_version return self.client.protocol_version
@property
def debug_client(self) -> TrezorClientDebugLink:
if not isinstance(self.client, TrezorClientDebugLink):
raise Exception("Debug client not available")
return self.client
def _write(self, msg: t.Any) -> None: def _write(self, msg: t.Any) -> None:
print("writing message:", msg.__class__.__name__) self._session._write(self.debug_client._filter_message(msg))
self._session._write(self._filter_message(msg))
def _read(self) -> t.Any: def _read(self) -> t.Any:
resp = self._filter_message(self._session._read()) resp = self.debug_client._filter_message(self._session._read())
print("reading message:", resp.__class__.__name__) if self.debug_client.actual_responses is not None:
if self.actual_responses is not None: self.debug_client.actual_responses.append(resp)
self.actual_responses.append(resp)
return resp return resp
def resume(self) -> None: def resume(self) -> None:
self._session.resume() self._session.resume()
def set_expected_responses(
self,
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
) -> None:
"""Set a sequence of expected responses to session calls.
Within a given with-block, the list of received responses from device must
match the list of expected responses, otherwise an ``AssertionError`` is raised.
If an expected response is given a field value other than ``None``, that field value
must exactly match the received field value. If a given field is ``None``
(or unspecified) in the expected response, the received field value is not
checked.
Each expected response can also be a tuple ``(bool, message)``. In that case, the
expected response is only evaluated if the first field is ``True``.
This is useful for differentiating sequences between Trezor models:
>>> trezor_one = session.features.model == "1"
>>> session.set_expected_responses([
>>> messages.ButtonRequest(code=ConfirmOutput),
>>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
>>> messages.Success(),
>>> ])
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
# make sure all items are (bool, message) tuples
expected_with_validity = (
e if isinstance(e, tuple) else (True, e) for e in expected
)
# only apply those items that are (True, message)
self.expected_responses = [
MessageFilter.from_message_or_type(expected)
for valid, expected in expected_with_validity
if valid
]
self.actual_responses = []
def lock(self) -> None: def lock(self) -> None:
"""Lock the device. """Lock the device.
@ -1037,6 +999,214 @@ class SessionDebugWrapper(Session):
btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH) btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH)
self.refresh_features() self.refresh_features()
class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses
# and other functionality for unit tests
# for various callbacks, created in order
# to automatically pass unit tests.
#
# This mixing should be used only for purposes
# of unit testing, because it will fail to work
# without special DebugLink interface provided
# by the device.
def __init__(
self,
transport: Transport,
auto_interact: bool = True,
open_transport: bool = True,
debug_transport: Transport | None = None,
) -> None:
try:
debug_transport = debug_transport or transport.find_debug()
self.debug = DebugLink(debug_transport, auto_interact)
if open_transport:
self.debug.open()
# try to open debuglink, see if it works
assert self.debug.transport.ping()
except Exception:
if not auto_interact:
self.debug = NullDebugLink()
else:
raise
if open_transport:
transport.open()
# set transport explicitly so that sync_responses can work
super().__init__(transport)
self.transport = transport
self.ui: DebugUI = DebugUI(self.debug)
def get_pin(_msg: messages.PinMatrixRequest) -> str:
try:
pin = self.ui.get_pin()
except Cancelled:
raise
return pin
self.pin_callback = get_pin
self.button_callback = self.ui.button_request
self.sync_responses()
# So that we can choose right screenshotting logic (T1 vs TT)
# and know the supported debug capabilities
self.debug.model = self.model
self.debug.version = self.version
self.reset_debug_features()
@property
def layout_type(self) -> LayoutType:
return self.debug.layout_type
def get_new_client(self) -> TrezorClientDebugLink:
new_client = TrezorClientDebugLink(
self.transport,
self.debug.allow_interactions,
open_transport=False,
debug_transport=self.debug.transport,
)
new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir
new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory
new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
return new_client
def close_transport(self) -> None:
self.transport.close()
self.debug.close()
def lock(self) -> None:
s = self.get_seedless_session()
s.lock()
def get_session(
self,
passphrase: str | object = "",
derive_cardano: bool = False,
) -> SessionDebugWrapper:
if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase)
session = SessionDebugWrapper(
super().get_session(
passphrase,
derive_cardano,
)
)
return session
# FIXME: can be deleted
def get_seedless_session(
self, *args: t.Any, **kwargs: t.Any
) -> SessionDebugWrapper:
session = super().get_seedless_session(*args, **kwargs)
if not isinstance(session, SessionDebugWrapper):
session = SessionDebugWrapper(session)
return session
def watch_layout(self, watch: bool = True) -> None:
"""Enable or disable watching layout changes.
Since trezor-core v2.3.2, it is necessary to call `watch_layout()` before
using `debug.wait_layout()`, otherwise layout changes are not reported.
"""
if self.version >= (2, 3, 2):
# version check is necessary because otherwise we cannot reliably detect
# whether and where to wait for reply:
# - T1 reports unknown debuglink messages on the wirelink
# - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
self.debug.watch_layout(watch)
def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
"""Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts.
"""
self.ui.pins = iter(pins)
def use_mnemonic(self, mnemonic: str) -> None:
"""Use the provided mnemonic to respond to device.
Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def sync_responses(self) -> None:
"""Synchronize Trezor device receiving with caller.
When a failed test does not read out the response, the next caller will write
a request, but read the previous response -- while the device had already sent
and placed into queue the new response.
This function will call `Ping` and read responses until it locates a `Success`
with the expected text. This means that we are reading up-to-date responses.
"""
import secrets
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
# prompt, which is in TINY mode and does not respond to `Ping`.
if self.protocol_version is ProtocolVersion.V1:
assert isinstance(self.protocol, ProtocolV1Channel)
self.protocol.write(messages.Cancel())
resp = self.protocol.read()
message = "SYNC" + secrets.token_hex(8)
self.protocol.write(messages.Ping(message=message))
while resp != messages.Success(message=message):
try:
resp = self.protocol.read()
except Exception:
pass
def mnemonic_callback(self, _) -> str:
word, pos = self.debug.read_recovery_word()
if word:
return word
if pos:
return self.mnemonic[pos - 1]
raise RuntimeError("Unexpected call")
def set_expected_responses(
self,
expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]],
) -> None:
"""Set a sequence of expected responses to session calls.
Within a given with-block, the list of received responses from device must
match the list of expected responses, otherwise an ``AssertionError`` is raised.
If an expected response is given a field value other than ``None``, that field value
must exactly match the received field value. If a given field is ``None``
(or unspecified) in the expected response, the received field value is not
checked.
Each expected response can also be a tuple ``(bool, message)``. In that case, the
expected response is only evaluated if the first field is ``True``.
This is useful for differentiating sequences between Trezor models:
>>> trezor_one = session.features.model == "1"
>>> client.set_expected_responses([
>>> messages.ButtonRequest(code=ConfirmOutput),
>>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)),
>>> messages.Success(),
>>> ])
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
# make sure all items are (bool, message) tuples
expected_with_validity = (
e if isinstance(e, tuple) else (True, e) for e in expected
)
# only apply those items that are (True, message)
self.expected_responses = [
MessageFilter.from_message_or_type(expected)
for valid, expected in expected_with_validity
if valid
]
self.actual_responses = []
def set_filter( def set_filter(
self, self,
message_type: t.Type[protobuf.MessageType], message_type: t.Type[protobuf.MessageType],
@ -1069,6 +1239,7 @@ class SessionDebugWrapper(Session):
Clears all debugging state that might have been modified by a testcase. Clears all debugging state that might have been modified by a testcase.
""" """
self.ui.clear()
self.in_with_statement = False self.in_with_statement = False
self.expected_responses: list[MessageFilter] | None = None self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None self.actual_responses: list[protobuf.MessageType] | None = None
@ -1077,7 +1248,7 @@ class SessionDebugWrapper(Session):
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {} ] = {}
def __enter__(self) -> "SessionDebugWrapper": def __enter__(self) -> "TrezorClientDebugLink":
# For usage in with/expected_responses # For usage in with/expected_responses
if self.in_with_statement: if self.in_with_statement:
raise RuntimeError("Do not nest!") raise RuntimeError("Do not nest!")
@ -1092,10 +1263,8 @@ class SessionDebugWrapper(Session):
actual_responses = self.actual_responses actual_responses = self.actual_responses
# grab a copy of the inputflow generator to raise an exception through it # grab a copy of the inputflow generator to raise an exception through it
if isinstance(self.client, TrezorClientDebugLink) and isinstance( if isinstance(self.ui, DebugUI):
self.client.ui, DebugUI input_flow = self.ui.input_flow
):
input_flow = self.client.ui.input_flow
else: else:
input_flow = None input_flow = None
@ -1105,7 +1274,6 @@ class SessionDebugWrapper(Session):
# If no other exception was raised, evaluate missed responses # If no other exception was raised, evaluate missed responses
# (raises AssertionError on mismatch) # (raises AssertionError on mismatch)
self._verify_responses(expected_responses, actual_responses) self._verify_responses(expected_responses, actual_responses)
elif isinstance(input_flow, t.Generator): elif isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in # Propagate the exception through the input flow, so that we see in
# traceback where it is stuck. # traceback where it is stuck.
@ -1165,205 +1333,9 @@ class SessionDebugWrapper(Session):
output.append("") output.append("")
return output return output
class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses
# and other functionality for unit tests
# for various callbacks, created in order
# to automatically pass unit tests.
#
# This mixing should be used only for purposes
# of unit testing, because it will fail to work
# without special DebugLink interface provided
# by the device.
def __init__(
self,
transport: Transport,
auto_interact: bool = True,
open_transport: bool = True,
debug_transport: Transport | None = None,
) -> None:
try:
debug_transport = debug_transport or transport.find_debug()
self.debug = DebugLink(debug_transport, auto_interact)
if open_transport:
self.debug.open()
# try to open debuglink, see if it works
assert self.debug.transport.ping()
except Exception:
if not auto_interact:
self.debug = NullDebugLink()
else:
raise
if open_transport:
transport.open()
# set transport explicitly so that sync_responses can work
super().__init__(transport)
self.transport = transport
self.ui: DebugUI = DebugUI(self.debug)
self.sync_responses()
# So that we can choose right screenshotting logic (T1 vs TT)
# and know the supported debug capabilities
self.debug.model = self.model
self.debug.version = self.version
@property
def layout_type(self) -> LayoutType:
return self.debug.layout_type
def get_new_client(self) -> TrezorClientDebugLink:
new_client = TrezorClientDebugLink(
self.transport,
self.debug.allow_interactions,
open_transport=False,
debug_transport=self.debug.transport,
)
new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir
new_client.debug.t1_screenshot_directory = self.debug.t1_screenshot_directory
new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
return new_client
def reset_debug_features(self) -> None:
"""
Prepare the debugging client for a new testcase.
Clears all debugging state that might have been modified by a testcase.
"""
self.ui: DebugUI = DebugUI(self.debug)
self.in_with_statement = False
def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
session._write(messages.ButtonAck())
self.ui.button_request(msg)
return session._read()
def pin_callback(self, session: Session, msg: messages.PinMatrixRequest) -> t.Any:
try:
pin = self.ui.get_pin(msg.type)
except Cancelled:
session.call_raw(messages.Cancel())
raise
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
session.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
def passphrase_callback(
self, session: Session, msg: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
def send_passphrase(
passphrase: str | None = None, on_device: bool | None = None
) -> MessageType:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
if resp.state is not None:
session.id = resp.state
else:
raise RuntimeError("Object resp.state is None")
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
if isinstance(session, SessionDebugWrapper):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
)
if passphrase is None:
passphrase = session.passphrase
else:
raise NotImplementedError
except Cancelled:
session.call_raw(messages.Cancel())
raise
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
session.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if passphrase is None:
passphrase = ""
if not isinstance(passphrase, str):
raise RuntimeError(f"Passphrase must be a str {type(passphrase)}")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
def close_transport(self) -> None:
self.transport.close()
self.debug.close()
def lock(self) -> None:
s = self.get_seedless_session()
s.lock()
def get_session(
self,
passphrase: str | object | None = None,
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionDebugWrapper:
if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase)
session = SessionDebugWrapper(
super().get_session(
passphrase, derive_cardano, session_id, should_derive=False
)
)
session.passphrase = passphrase
return session
def get_seedless_session(
self, *args: t.Any, **kwargs: t.Any
) -> SessionDebugWrapper:
session = super().get_seedless_session(*args, **kwargs)
if not isinstance(session, SessionDebugWrapper):
session = SessionDebugWrapper(session)
return session
def resume_session(self, session: Session) -> SessionDebugWrapper:
if isinstance(session, SessionDebugWrapper):
session._session = super().resume_session(session._session)
return session
else:
return SessionDebugWrapper(super().resume_session(session))
def set_input_flow( def set_input_flow(
self, input_flow: InputFlowType | t.Callable[[], InputFlowType] self,
input_flow: InputFlowType | t.Callable[[], InputFlowType],
) -> None: ) -> None:
"""Configure a sequence of input events for the current with-block. """Configure a sequence of input events for the current with-block.
@ -1387,7 +1359,7 @@ class TrezorClientDebugLink(TrezorClient):
>>> >>>
>>> with client: >>> with client:
>>> client.set_input_flow(input_flow) >>> client.set_input_flow(input_flow)
>>> some_call(client) >>> some_call(session)
""" """
if not self.in_with_statement: if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement") raise RuntimeError("Must be called inside 'with' statement")
@ -1397,109 +1369,9 @@ class TrezorClientDebugLink(TrezorClient):
if not hasattr(input_flow, "send"): if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function") raise RuntimeError("input_flow should be a generator function")
self.ui.input_flow = input_flow self.ui.input_flow = input_flow
next(input_flow) # start the generator next(input_flow) # start the generator
def watch_layout(self, watch: bool = True) -> None:
"""Enable or disable watching layout changes.
Since trezor-core v2.3.2, it is necessary to call `watch_layout()` before
using `debug.wait_layout()`, otherwise layout changes are not reported.
"""
if self.version >= (2, 3, 2):
# version check is necessary because otherwise we cannot reliably detect
# whether and where to wait for reply:
# - T1 reports unknown debuglink messages on the wirelink
# - TT < 2.3.0 does not reply to unknown debuglink messages due to a bug
self.debug.watch_layout(watch)
def __enter__(self) -> "TrezorClientDebugLink":
# For usage in with/expected_responses
if self.in_with_statement:
raise RuntimeError("Do not nest!")
self.in_with_statement = True
return self
def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# grab a copy of the inputflow generator to raise an exception through it
if isinstance(self.ui, DebugUI):
input_flow = self.ui.input_flow
else:
input_flow = None
self.reset_debug_features()
if exc_type is not None and isinstance(input_flow, t.Generator):
# Propagate the exception through the input flow, so that we see in
# traceback where it is stuck.
input_flow.throw(exc_type, value, traceback)
def use_pin_sequence(self, pins: t.Iterable[str]) -> None:
"""Respond to PIN prompts from device with the provided PINs.
The sequence must be at least as long as the expected number of PIN prompts.
"""
self.ui.pins = iter(pins)
def use_mnemonic(self, mnemonic: str) -> None:
"""Use the provided mnemonic to respond to device.
Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
@staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output: list[str] = []
output.append("Expected responses:")
if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)")
for i in range(start_at, stop_at):
exp = expected[i]
prefix = " " if i != current else ">>> "
output.append(textwrap.indent(exp.to_string(), prefix))
if stop_at < len(expected):
omitted = len(expected) - stop_at
output.append(f" (...{omitted} following responses omitted)")
output.append("")
return output
def sync_responses(self) -> None:
"""Synchronize Trezor device receiving with caller.
When a failed test does not read out the response, the next caller will write
a request, but read the previous response -- while the device had already sent
and placed into queue the new response.
This function will call `Ping` and read responses until it locates a `Success`
with the expected text. This means that we are reading up-to-date responses.
"""
import secrets
# Start by canceling whatever is on screen. This will work to cancel T1 PIN
# prompt, which is in TINY mode and does not respond to `Ping`.
if self.protocol_version is ProtocolVersion.V1:
assert isinstance(self.protocol, ProtocolV1Channel)
self.protocol.write(messages.Cancel())
resp = self.protocol.read()
message = "SYNC" + secrets.token_hex(8)
self.protocol.write(messages.Ping(message=message))
while resp != messages.Success(message=message):
try:
resp = self.protocol.read()
except Exception:
pass
def mnemonic_callback(self, _) -> str:
word, pos = self.debug.read_recovery_word()
if word:
return word
if pos:
return self.mnemonic[pos - 1]
raise RuntimeError("Unexpected call")
def load_device( def load_device(
session: "Session", session: "Session",

View File

@ -56,7 +56,7 @@ def pin_input_flow(client: Client, old_pin: str, new_pin: str):
if __name__ == "__main__": if __name__ == "__main__":
wirelink = get_device() wirelink = get_device()
client = Client(wirelink) client = Client(wirelink)
client.open() session = client.get_seedless_session()
i = 0 i = 0
@ -83,3 +83,5 @@ if __name__ == "__main__":
print(f"iteration {i}") print(f"iteration {i}")
i = i + 1 i = i + 1
wirelink.close()

View File

@ -195,13 +195,15 @@ def test_autolock_does_not_interrupt_signing(device_handler: "BackgroundDeviceHa
assert TR.send__total_amount in layout.text_content() assert TR.send__total_amount in layout.text_content()
assert "0.0039 BTC" in layout.text_content() assert "0.0039 BTC" in layout.text_content()
client = session.client
def sleepy_filter(msg: MessageType) -> MessageType: def sleepy_filter(msg: MessageType) -> MessageType:
time.sleep(10.1) time.sleep(10.1)
session.set_filter(messages.TxAck, None) client.set_filter(messages.TxAck, None)
return msg return msg
with session, device_handler.client: with client:
session.set_filter(messages.TxAck, sleepy_filter) client.set_filter(messages.TxAck, sleepy_filter)
# confirm transaction # confirm transaction
if debug.layout_type is LayoutType.Bolt: if debug.layout_type is LayoutType.Bolt:
debug.click(debug.screen_buttons.ok(), hold_ms=1000) debug.click(debug.screen_buttons.ok(), hold_ms=1000)
@ -546,15 +548,17 @@ def test_autolock_does_not_interrupt_preauthorized(
no_fee_indices=[], no_fee_indices=[],
) )
client = session.client
def sleepy_filter(msg: MessageType) -> MessageType: def sleepy_filter(msg: MessageType) -> MessageType:
time.sleep(10.1) time.sleep(10.1)
session.set_filter(messages.SignTx, None) client.set_filter(messages.SignTx, None)
return msg return msg
with session: with client:
# Start DoPreauthorized flow when device is unlocked. Wait 10s before # Start DoPreauthorized flow when device is unlocked. Wait 10s before
# delivering SignTx, by that time autolock timer should have fired. # delivering SignTx, by that time autolock timer should have fired.
session.set_filter(messages.SignTx, sleepy_filter) client.set_filter(messages.SignTx, sleepy_filter)
device_handler.run_with_provided_session( device_handler.run_with_provided_session(
session, session,
btc.sign_tx, btc.sign_tx,

View File

@ -22,6 +22,10 @@ udp.SOCKET_TIMEOUT = 0.1
class NullUI: class NullUI:
@staticmethod
def clear(*args, **kwargs):
pass
@staticmethod @staticmethod
def button_request(code): def button_request(code):
pass pass

View File

@ -52,7 +52,7 @@ def test_binance_get_address_chunkify_details(
# data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50 # data from https://github.com/binance-chain/javascript-sdk/blob/master/__tests__/crypto.test.js#L50
with session.client as client: with session.client as client:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
address = get_address( address = get_address(
session, parse_path(path), show_display=True, chunkify=True session, parse_path(path), show_display=True, chunkify=True

View File

@ -33,7 +33,7 @@ BINANCE_PATH = parse_path("m/44h/714h/0h/0/0")
) )
def test_binance_get_public_key(session: Session): def test_binance_get_public_key(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
sig = binance.get_public_key(session, BINANCE_PATH, show_display=True) sig = binance.get_public_key(session, BINANCE_PATH, show_display=True)
assert ( assert (

View File

@ -66,7 +66,7 @@ def test_sign_tx(session: Session, chunkify: bool):
commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big") commitment_data = b"\x0fwww.example.com" + (1).to_bytes(ROUND_ID_LEN, "big")
with session.client as client: with session.client as client:
client.use_pin_sequence([PIN]) session.client.use_pin_sequence([PIN])
btc.authorize_coinjoin( btc.authorize_coinjoin(
session, session,
coordinator="www.example.com", coordinator="www.example.com",
@ -80,8 +80,8 @@ def test_sign_tx(session: Session, chunkify: bool):
session.call(messages.LockDevice()) session.call(messages.LockDevice())
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[messages.PreauthorizedRequest, messages.OwnershipProof] [messages.PreauthorizedRequest, messages.OwnershipProof]
) )
btc.get_ownership_proof( btc.get_ownership_proof(
@ -94,8 +94,8 @@ def test_sign_tx(session: Session, chunkify: bool):
preauthorized=True, preauthorized=True,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[messages.PreauthorizedRequest, messages.OwnershipProof] [messages.PreauthorizedRequest, messages.OwnershipProof]
) )
btc.get_ownership_proof( btc.get_ownership_proof(
@ -207,8 +207,8 @@ def test_sign_tx(session: Session, chunkify: bool):
no_fee_indices=[], no_fee_indices=[],
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
messages.PreauthorizedRequest(), messages.PreauthorizedRequest(),
request_input(0), request_input(0),
@ -452,8 +452,8 @@ def test_sign_tx_spend(session: Session):
prev_txes=TX_CACHE_TESTNET, prev_txes=TX_CACHE_TESTNET,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=B.Other), messages.ButtonRequest(code=B.Other),
messages.UnlockedPathRequest, messages.UnlockedPathRequest,
@ -526,8 +526,8 @@ def test_sign_tx_migration(session: Session):
prev_txes=TX_CACHE_TESTNET, prev_txes=TX_CACHE_TESTNET,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=B.Other), messages.ButtonRequest(code=B.Other),
messages.UnlockedPathRequest, messages.UnlockedPathRequest,
@ -666,8 +666,8 @@ def test_get_public_key(session: Session):
) )
# Get unlock path MAC. # Get unlock path MAC.
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=B.Other), messages.ButtonRequest(code=B.Other),
messages.UnlockedPathRequest, messages.UnlockedPathRequest,
@ -689,8 +689,8 @@ def test_get_public_key(session: Session):
) )
# Ensure that user does not need to confirm access when path unlock is requested with MAC. # Ensure that user does not need to confirm access when path unlock is requested with MAC.
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
messages.UnlockedPathRequest, messages.UnlockedPathRequest,
messages.PublicKey, messages.PublicKey,
@ -720,8 +720,8 @@ def test_get_address(session: Session):
) )
# Unlock CoinJoin path. # Unlock CoinJoin path.
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=B.Other), messages.ButtonRequest(code=B.Other),
messages.UnlockedPathRequest, messages.UnlockedPathRequest,

View File

@ -72,8 +72,8 @@ def test_send_bch_change(session: Session):
amount=73_452, amount=73_452,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -124,8 +124,8 @@ def test_send_bch_nochange(session: Session):
amount=1_934_960, amount=1_934_960,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -182,8 +182,8 @@ def test_send_bch_oldaddr(session: Session):
amount=1_934_960, amount=1_934_960,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -252,9 +252,9 @@ def test_attack_change_input(session: Session):
return msg return msg
with session: with session.client as client:
session.set_filter(messages.TxAck, attack_processor) client.set_filter(messages.TxAck, attack_processor)
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -327,8 +327,8 @@ def test_send_bch_multisig_wrongchange(session: Session):
script_type=messages.OutputScriptType.PAYTOMULTISIG, script_type=messages.OutputScriptType.PAYTOMULTISIG,
amount=23_000, amount=23_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -395,8 +395,8 @@ def test_send_bch_multisig_change(session: Session):
script_type=messages.OutputScriptType.PAYTOMULTISIG, script_type=messages.OutputScriptType.PAYTOMULTISIG,
amount=24_000, amount=24_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -434,8 +434,8 @@ def test_send_bch_multisig_change(session: Session):
) )
out2.address_n[2] = H_(1) out2.address_n[2] = H_(1)
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -496,8 +496,8 @@ def test_send_bch_external_presigned(session: Session):
amount=1_934_960, amount=1_934_960,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),

View File

@ -71,8 +71,8 @@ def test_send_bitcoin_gold_change(session: Session):
amount=1_252_382_934 - 1_896_050 - 1_000, amount=1_252_382_934 - 1_896_050 - 1_000,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -124,8 +124,8 @@ def test_send_bitcoin_gold_nochange(session: Session):
amount=1_252_382_934 + 38_448_607 - 1_000, amount=1_252_382_934 + 38_448_607 - 1_000,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -193,9 +193,9 @@ def test_attack_change_input(session: Session):
return msg return msg
with session: with session.client as client:
session.set_filter(messages.TxAck, attack_processor) client.set_filter(messages.TxAck, attack_processor)
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -254,8 +254,8 @@ def test_send_btg_multisig_change(session: Session):
script_type=messages.OutputScriptType.PAYTOMULTISIG, script_type=messages.OutputScriptType.PAYTOMULTISIG,
amount=1_252_382_934 - 24_000 - 1_000, amount=1_252_382_934 - 24_000 - 1_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -293,8 +293,8 @@ def test_send_btg_multisig_change(session: Session):
) )
out2.address_n[2] = H_(1) out2.address_n[2] = H_(1)
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -347,8 +347,8 @@ def test_send_p2sh(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
amount=1_252_382_934 - 11_000 - 12_300_000, amount=1_252_382_934 - 11_000 - 12_300_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -400,8 +400,8 @@ def test_send_p2sh_witness_change(session: Session):
script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, script_type=messages.OutputScriptType.PAYTOP2SHWITNESS,
amount=1_252_382_934 - 11_000 - 12_300_000, amount=1_252_382_934 - 11_000 - 12_300_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -460,8 +460,8 @@ def test_send_multisig_1(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -484,7 +484,7 @@ def test_send_multisig_1(session: Session):
inp1.multisig.signatures[0] = signatures[0] inp1.multisig.signatures[0] = signatures[0]
# sign with third key # sign with third key
inp1.address_n[2] = H_(3) inp1.address_n[2] = H_(3)
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -537,7 +537,7 @@ def test_send_mixed_inputs(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client:
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API session, "Bgold", [inp1, inp2], [out1], prev_txes=TX_API
) )
@ -577,8 +577,8 @@ def test_send_btg_external_presigned(session: Session):
amount=1_252_382_934 + 58_456 - 1_000, amount=1_252_382_934 + 58_456 - 1_000,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),

View File

@ -57,8 +57,8 @@ def test_send_dash(session: Session):
amount=999_999_000, amount=999_999_000,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -106,8 +106,8 @@ def test_send_dash_dip2_input(session: Session):
amount=95_000_000, amount=95_000_000,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),

View File

@ -76,8 +76,8 @@ def test_send_decred(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -133,8 +133,8 @@ def test_purchase_ticket_decred(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -197,8 +197,8 @@ def test_spend_from_stake_generation_and_revocation_decred(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -278,8 +278,8 @@ def test_send_decred_change(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -384,8 +384,8 @@ def test_decred_multisig_change(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),

View File

@ -169,7 +169,7 @@ def test_descriptors(
session: Session, coin, account, purpose, script_type, descriptors session: Session, coin, account, purpose, script_type, descriptors
): ):
with session.client as client: with session.client as client:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
address_n = _address_n(purpose, coin, account, script_type) address_n = _address_n(purpose, coin, account, script_type)
@ -192,8 +192,8 @@ def test_descriptors_trezorlib(
session: Session, coin, account, purpose, script_type, descriptors session: Session, coin, account, purpose, script_type, descriptors
): ):
with session.client as client: with session.client as client:
if client.model != models.T1B1: if session.client.model != models.T1B1:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
res = btc_cli._get_descriptor( res = btc_cli._get_descriptor(
session, coin, account, purpose, script_type, show_display=True session, coin, account, purpose, script_type, show_display=True

View File

@ -272,7 +272,7 @@ def test_multisig(session: Session):
for nr in range(1, 4): for nr in range(1, 4):
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
@ -321,9 +321,9 @@ def test_multisig_missing(session: Session, show_display):
) )
for multisig in (multisig1, multisig2): for multisig in (multisig1, multisig2):
with session.client as client, pytest.raises(TrezorFailure): with pytest.raises(TrezorFailure), session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.get_address( btc.get_address(
session, session,
@ -347,7 +347,7 @@ def test_bch_multisig(session: Session):
for nr in range(1, 4): for nr in range(1, 4):
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
@ -396,8 +396,8 @@ def test_invalid_path(session: Session):
def test_unknown_path(session: Session): def test_unknown_path(session: Session):
UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0") UNKNOWN_PATH = parse_path("m/44h/9h/0h/0/0")
with session: with session.client as client:
session.set_expected_responses([messages.Failure]) client.set_expected_responses([messages.Failure])
with pytest.raises(TrezorFailure, match="Forbidden key path"): with pytest.raises(TrezorFailure, match="Forbidden key path"):
# account number is too high # account number is too high
@ -406,8 +406,8 @@ def test_unknown_path(session: Session):
# disable safety checks # disable safety checks
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
with session, session.client as client: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest( messages.ButtonRequest(
code=messages.ButtonRequestType.UnknownDerivationPath code=messages.ButtonRequestType.UnknownDerivationPath
@ -417,14 +417,14 @@ def test_unknown_path(session: Session):
] ]
) )
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# try again with a warning # try again with a warning
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True) btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=True)
with session: with session.client as client:
# no warning is displayed when the call is silent # no warning is displayed when the call is silent
session.set_expected_responses([messages.Address]) client.set_expected_responses([messages.Address])
btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False) btc.get_address(session, "Bitcoin", UNKNOWN_PATH, show_display=False)
@ -455,9 +455,9 @@ def test_multisig_different_paths(session: Session):
with pytest.raises( with pytest.raises(
Exception, match="Using different paths for different xpubs is not allowed" Exception, match="Using different paths for different xpubs is not allowed"
): ):
with session.client as client, session: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.get_address( btc.get_address(
session, session,
@ -471,7 +471,7 @@ def test_multisig_different_paths(session: Session):
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.get_address( btc.get_address(
session, session,

View File

@ -76,7 +76,7 @@ def test_show_segwit(session: Session):
def test_show_segwit_altcoin(session: Session): def test_show_segwit_altcoin(session: Session):
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(

View File

@ -89,7 +89,7 @@ def test_show_tt(
address: str, address: str,
): ):
with session.client as client: with session.client as client:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
@ -110,7 +110,7 @@ def test_show_cancel(
session: Session, path: str, script_type: messages.InputScriptType, address: str session: Session, path: str, script_type: messages.InputScriptType, address: str
): ):
with session.client as client, pytest.raises(Cancelled): with session.client as client, pytest.raises(Cancelled):
IF = InputFlowShowAddressQRCodeCancel(client) IF = InputFlowShowAddressQRCodeCancel(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.get_address( btc.get_address(
session, session,
@ -159,7 +159,7 @@ def test_show_multisig_3(session: Session):
for i in [1, 2, 3]: for i in [1, 2, 3]:
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(
@ -273,11 +273,11 @@ def test_show_multisig_xpubs(
) )
for i in range(3): for i in range(3):
with session, session.client as client: with session.client as client:
IF = InputFlowShowMultisigXPUBs(client, address, xpubs, i) IF = InputFlowShowMultisigXPUBs(session.client, address, xpubs, i)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
client.debug.synchronize_at("Homescreen") session.client.debug.synchronize_at("Homescreen")
client.watch_layout() session.client.watch_layout()
btc.get_address( btc.get_address(
session, session,
"Bitcoin", "Bitcoin",
@ -316,7 +316,7 @@ def test_show_multisig_15(session: Session):
for i in range(15): for i in range(15):
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
assert ( assert (
btc.get_address( btc.get_address(

View File

@ -120,7 +120,7 @@ def test_get_public_node(session: Session, coin_name, xpub_magic, path, xpub):
@pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN) @pytest.mark.parametrize("coin_name, xpub_magic, path, xpub", VECTORS_BITCOIN)
def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub): def test_get_public_node_show(session: Session, coin_name, xpub_magic, path, xpub):
with session.client as client: with session.client as client:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True)
assert res.xpub == xpub assert res.xpub == xpub
@ -158,7 +158,7 @@ def test_get_public_node_show_legacy(
client.debug.press_yes() # finish the flow client.debug.press_yes() # finish the flow
yield yield
with client: with session.client as client:
# test XPUB display flow (without showing QR code) # test XPUB display flow (without showing QR code)
res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True) res = btc.get_public_node(session, path, coin_name=coin_name, show_display=True)
assert res.xpub == xpub assert res.xpub == xpub

View File

@ -61,8 +61,8 @@ def test_one_one_fee_sapling(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -125,8 +125,8 @@ def test_one_one_rewards_claim(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),

View File

@ -101,8 +101,8 @@ def test_2_of_3(session: Session, chunkify: bool):
request_finished(), request_finished(),
] ]
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
# Now we have first signature # Now we have first signature
signatures1, _ = btc.sign_tx( signatures1, _ = btc.sign_tx(
@ -143,8 +143,8 @@ def test_2_of_3(session: Session, chunkify: bool):
multisig=multisig, multisig=multisig,
) )
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
signatures2, serialized_tx = btc.sign_tx( signatures2, serialized_tx = btc.sign_tx(
session, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp3], [out1], prev_txes=TX_API_TESTNET
) )
@ -362,7 +362,7 @@ def test_15_of_15(session: Session):
multisig=multisig, multisig=multisig,
) )
with session: with session.client:
sig, serialized_tx = btc.sign_tx( sig, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -424,6 +424,7 @@ def test_attack_change_input(session: Session):
attacker to provide a 1-of-2 multisig change address. When `input_real` attacker to provide a 1-of-2 multisig change address. When `input_real`
is provided in the signing phase, an error must occur. is provided in the signing phase, an error must occur.
""" """
client = session.client
address_n = parse_path("m/48h/1h/0h/1h/0/0") # 2NErUdruXmM8o8bQySrzB3WdBRcmc5br4E8 address_n = parse_path("m/48h/1h/0h/1h/0/0") # 2NErUdruXmM8o8bQySrzB3WdBRcmc5br4E8
attacker_multisig_public_key = bytes.fromhex( attacker_multisig_public_key = bytes.fromhex(
"03653a148b68584acb97947344a7d4fd6a6f8b8485cad12987ff8edac874268088" "03653a148b68584acb97947344a7d4fd6a6f8b8485cad12987ff8edac874268088"
@ -475,7 +476,7 @@ def test_attack_change_input(session: Session):
) )
# Transaction can be signed without the attack processor # Transaction can be signed without the attack processor
with session.client as client: with client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
@ -497,8 +498,8 @@ def test_attack_change_input(session: Session):
attack_count -= 1 attack_count -= 1
return msg return msg
with session: with client:
session.set_filter(messages.TxAck, attack_processor) client.set_filter(messages.TxAck, attack_processor)
with pytest.raises(TrezorFailure): with pytest.raises(TrezorFailure):
btc.sign_tx( btc.sign_tx(
session, session,

View File

@ -263,8 +263,8 @@ def test_external_external(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses(_responses(session, INP1, INP2)) client.set_expected_responses(_responses(session, INP1, INP2))
btc.sign_tx( btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",
@ -288,8 +288,8 @@ def test_external_internal(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session, session.client as client: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
_responses( _responses(
session, session,
INP1, INP1,
@ -299,7 +299,7 @@ def test_external_internal(session: Session):
) )
) )
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
@ -324,8 +324,8 @@ def test_internal_external(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session, session.client as client: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
_responses( _responses(
session, session,
INP1, INP1,
@ -335,7 +335,7 @@ def test_internal_external(session: Session):
) )
) )
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
session, session,
@ -360,8 +360,8 @@ def test_multisig_external_external(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses(_responses(session, INP1, INP2)) client.set_expected_responses(_responses(session, INP1, INP2))
btc.sign_tx( btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",
@ -393,8 +393,8 @@ def test_multisig_change_match_first(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
_responses(session, INP1, INP2, change_indices=[1]) _responses(session, INP1, INP2, change_indices=[1])
) )
btc.sign_tx( btc.sign_tx(
@ -428,8 +428,8 @@ def test_multisig_change_match_second(session: Session):
script_type=messages.OutputScriptType.PAYTOMULTISIG, script_type=messages.OutputScriptType.PAYTOMULTISIG,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
_responses(session, INP1, INP2, change_indices=[2]) _responses(session, INP1, INP2, change_indices=[2])
) )
btc.sign_tx( btc.sign_tx(
@ -464,8 +464,8 @@ def test_sorted_multisig_change_match_first(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
_responses(session, INP4, INP5, change_indices=[1]) _responses(session, INP4, INP5, change_indices=[1])
) )
btc.sign_tx( btc.sign_tx(
@ -499,8 +499,8 @@ def test_multisig_mismatch_multisig_change(session: Session):
script_type=messages.OutputScriptType.PAYTOMULTISIG, script_type=messages.OutputScriptType.PAYTOMULTISIG,
) )
with session: with session.client as client:
session.set_expected_responses(_responses(session, INP1, INP2)) client.set_expected_responses(_responses(session, INP1, INP2))
btc.sign_tx( btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",
@ -532,8 +532,8 @@ def test_sorted_multisig_mismatch_multisig_change(session: Session):
script_type=messages.OutputScriptType.PAYTOMULTISIG, script_type=messages.OutputScriptType.PAYTOMULTISIG,
) )
with session: with session.client as client:
session.set_expected_responses(_responses(session, INP4, INP5)) client.set_expected_responses(_responses(session, INP4, INP5))
btc.sign_tx( btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",
@ -568,8 +568,8 @@ def test_multisig_mismatch_multisig_change_different_paths(session: Session):
script_type=messages.OutputScriptType.PAYTOMULTISIG, script_type=messages.OutputScriptType.PAYTOMULTISIG,
) )
with session: with session.client as client:
session.set_expected_responses(_responses(session, INP1, INP2)) client.set_expected_responses(_responses(session, INP1, INP2))
btc.sign_tx( btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",
@ -601,8 +601,8 @@ def test_multisig_mismatch_inputs(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses(_responses(session, INP1, INP3)) client.set_expected_responses(_responses(session, INP1, INP3))
btc.sign_tx( btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",
@ -635,8 +635,8 @@ def test_sorted_multisig_mismatch_inputs(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses(_responses(session, INP4, INP6)) client.set_expected_responses(_responses(session, INP4, INP6))
btc.sign_tx( btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",

View File

@ -115,7 +115,7 @@ def test_getaddress(
for script_type in script_types: for script_type in script_types:
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
res = btc.get_address( res = btc.get_address(
session, session,
@ -136,7 +136,7 @@ def test_signmessage(
for script_type in script_types: for script_type in script_types:
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
sig = btc.sign_message( sig = btc.sign_message(
@ -177,7 +177,7 @@ def test_signtx(
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx}
@ -204,7 +204,7 @@ def test_getaddress_multisig(
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
address = btc.get_address( address = btc.get_address(
session, session,
@ -263,7 +263,7 @@ def test_signtx_multisig(session: Session, paths: list[str], address_index: list
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
sig, _ = btc.sign_tx( sig, _ = btc.sign_tx(
session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx} session, "Bitcoin", [inp1], [out1], prev_txes={prevhash: prevtx}

View File

@ -63,8 +63,8 @@ def test_opreturn(session: Session):
script_type=messages.OutputScriptType.PAYTOOPRETURN, script_type=messages.OutputScriptType.PAYTOOPRETURN,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -110,8 +110,8 @@ def test_nonzero_opreturn(session: Session):
script_type=messages.OutputScriptType.PAYTOOPRETURN, script_type=messages.OutputScriptType.PAYTOOPRETURN,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[request_input(0), request_output(0), messages.Failure()] [request_input(0), request_output(0), messages.Failure()]
) )
@ -136,8 +136,8 @@ def test_opreturn_address(session: Session):
script_type=messages.OutputScriptType.PAYTOOPRETURN, script_type=messages.OutputScriptType.PAYTOOPRETURN,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[request_input(0), request_output(0), messages.Failure()] [request_input(0), request_output(0), messages.Failure()]
) )
with pytest.raises( with pytest.raises(

View File

@ -328,7 +328,7 @@ def test_signmessage_long(
signature: str, signature: str,
): ):
with session.client as client: with session.client as client:
IF = InputFlowSignVerifyMessageLong(client) IF = InputFlowSignVerifyMessageLong(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
sig = btc.sign_message( sig = btc.sign_message(
session, session,
@ -357,7 +357,7 @@ def test_signmessage_info(
signature: str, signature: str,
): ):
with session.client as client, pytest.raises(Cancelled): with session.client as client, pytest.raises(Cancelled):
IF = InputFlowSignMessageInfo(client) IF = InputFlowSignMessageInfo(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
sig = btc.sign_message( sig = btc.sign_message(
session, session,
@ -395,7 +395,7 @@ def test_signmessage_pagination(session: Session, message: str, is_long: bool):
InputFlowSignVerifyMessageLong InputFlowSignVerifyMessageLong
if is_long if is_long
else InputFlowSignMessagePagination else InputFlowSignMessagePagination
)(client) )(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_message( btc.sign_message(
session, session,
@ -417,8 +417,8 @@ def test_signmessage_pagination_trailing_newline(session: Session):
message = "THIS\nMUST\nNOT\nBE\nPAGINATED\n" message = "THIS\nMUST\nNOT\nBE\nPAGINATED\n"
# The trailing newline must not cause a new paginated screen to appear. # The trailing newline must not cause a new paginated screen to appear.
# The UI must be a single dialog without pagination. # The UI must be a single dialog without pagination.
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
# expect address confirmation # expect address confirmation
message_filters.ButtonRequest(code=messages.ButtonRequestType.Other), message_filters.ButtonRequest(code=messages.ButtonRequestType.Other),
@ -438,8 +438,8 @@ def test_signmessage_pagination_trailing_newline(session: Session):
def test_signmessage_path_warning(session: Session): def test_signmessage_path_warning(session: Session):
message = "This is an example of a signed message." message = "This is an example of a signed message."
with session, session.client as client: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
# expect a path warning # expect a path warning
message_filters.ButtonRequest( message_filters.ButtonRequest(
@ -451,7 +451,7 @@ def test_signmessage_path_warning(session: Session):
] ]
) )
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_message( btc.sign_message(
session, session,

View File

@ -125,8 +125,8 @@ def test_one_one_fee(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -178,8 +178,8 @@ def test_testnet_one_two_fee(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -228,8 +228,8 @@ def test_testnet_fee_high_warning(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -280,8 +280,8 @@ def test_one_two_fee(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -342,8 +342,8 @@ def test_one_three_fee(session: Session, chunkify: bool):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -413,8 +413,8 @@ def test_two_two(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -557,8 +557,8 @@ def test_lots_of_change(session: Session):
request_change_outputs = [request_output(i + 1) for i in range(cnt)] request_change_outputs = [request_output(i + 1) for i in range(cnt)]
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -608,8 +608,8 @@ def test_fee_high_warning(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -665,7 +665,7 @@ def test_fee_high_hardfail(session: Session):
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
) )
with session.client as client: with session.client as client:
IF = InputFlowSignTxHighFee(client) IF = InputFlowSignTxHighFee(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
@ -696,8 +696,8 @@ def test_not_enough_funds(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -726,8 +726,8 @@ def test_p2sh(session: Session):
script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, script_type=messages.OutputScriptType.PAYTOSCRIPTHASH,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -785,6 +785,7 @@ def test_testnet_big_amount(session: Session):
def test_attack_change_outputs(session: Session): def test_attack_change_outputs(session: Session):
# input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a # input tx: ac4ca0e7827a1228f44449cb57b4b9a809a667ca044dc43bb124627fed4bc10a
client = session.client
inp1 = messages.TxInputType( inp1 = messages.TxInputType(
address_n=parse_path("m/44h/0h/0h/0/55"), # 14nw9rFTWGUncHZjSqpPSJQaptWW7iRRB8 address_n=parse_path("m/44h/0h/0h/0/55"), # 14nw9rFTWGUncHZjSqpPSJQaptWW7iRRB8
@ -813,8 +814,8 @@ def test_attack_change_outputs(session: Session):
) )
# Test if the transaction can be signed normally # Test if the transaction can be signed normally
with session: with client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -869,11 +870,11 @@ def test_attack_change_outputs(session: Session):
return msg return msg
with session, pytest.raises( with client, pytest.raises(
TrezorFailure, match="Transaction has changed during signing" TrezorFailure, match="Transaction has changed during signing"
): ):
# Set up attack processors # Set up attack processors
session.set_filter(messages.TxAck, attack_processor) client.set_filter(messages.TxAck, attack_processor)
btc.sign_tx( btc.sign_tx(
session, session,
@ -924,11 +925,11 @@ def test_attack_modify_change_address(session: Session):
return msg return msg
with session, pytest.raises( with session.client as client, pytest.raises(
TrezorFailure, match="Transaction has changed during signing" TrezorFailure, match="Transaction has changed during signing"
): ):
# Set up attack processors # Set up attack processors
session.set_filter(messages.TxAck, attack_processor) client.set_filter(messages.TxAck, attack_processor)
btc.sign_tx( btc.sign_tx(
session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET session, "Testnet", [inp1], [out1, out2], prev_txes=TX_CACHE_TESTNET
@ -982,9 +983,9 @@ def test_attack_change_input_address(session: Session):
return msg return msg
# Now run the attack, must trigger the exception # Now run the attack, must trigger the exception
with session: with session.client as client:
session.set_filter(messages.TxAck, attack_processor) client.set_filter(messages.TxAck, attack_processor)
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -1033,8 +1034,8 @@ def test_spend_coinbase(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -1091,8 +1092,8 @@ def test_two_changes(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -1150,8 +1151,8 @@ def test_change_on_main_chain_allowed(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -1418,8 +1419,8 @@ def test_lock_time(session: Session, lock_time: int, sequence: int):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -1468,7 +1469,7 @@ def test_lock_time_blockheight(session: Session):
) )
with session.client as client: with session.client as client:
IF = InputFlowLockTimeBlockHeight(client, "499999999") IF = InputFlowLockTimeBlockHeight(session.client, "499999999")
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
@ -1507,7 +1508,7 @@ def test_lock_time_datetime(session: Session, lock_time_str: str):
lock_time_timestamp = int(lock_time_utc.timestamp()) lock_time_timestamp = int(lock_time_utc.timestamp())
with session.client as client: with session.client as client:
IF = InputFlowLockTimeDatetime(client, lock_time_str) IF = InputFlowLockTimeDatetime(session.client, lock_time_str)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
@ -1539,7 +1540,7 @@ def test_information(session: Session):
) )
with session.client as client: with session.client as client:
IF = InputFlowSignTxInformation(client) IF = InputFlowSignTxInformation(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
@ -1574,7 +1575,7 @@ def test_information_mixed(session: Session):
) )
with session.client as client: with session.client as client:
IF = InputFlowSignTxInformationMixed(client) IF = InputFlowSignTxInformationMixed(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
@ -1605,7 +1606,7 @@ def test_information_cancel(session: Session):
) )
with session.client as client, pytest.raises(Cancelled): with session.client as client, pytest.raises(Cancelled):
IF = InputFlowSignTxInformationCancel(client) IF = InputFlowSignTxInformationCancel(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(
@ -1653,7 +1654,7 @@ def test_information_replacement(session: Session):
) )
with session.client as client: with session.client as client:
IF = InputFlowSignTxInformationReplacement(client) IF = InputFlowSignTxInformationReplacement(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx( btc.sign_tx(

View File

@ -61,7 +61,7 @@ def test_signtx_testnet(session: Session, amount_unit):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
amount=100_000 - 40_000 - 10_000, amount=100_000 - 40_000 - 10_000,
) )
with session: with session.client:
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, session,
"Testnet", "Testnet",
@ -95,7 +95,7 @@ def test_signtx_btc(session: Session, amount_unit):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client:
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, session,
"Bitcoin", "Bitcoin",

View File

@ -142,7 +142,7 @@ def test_p2pkh_presigned(session: Session):
) )
# Test with first input as pre-signed external. # Test with first input as pre-signed external.
with session: with session.client:
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, session,
"Testnet", "Testnet",
@ -155,7 +155,7 @@ def test_p2pkh_presigned(session: Session):
assert serialized_tx.hex() == expected_tx assert serialized_tx.hex() == expected_tx
# Test with second input as pre-signed external. # Test with second input as pre-signed external.
with session: with session.client:
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, session,
"Testnet", "Testnet",
@ -216,8 +216,8 @@ def test_p2wpkh_in_p2sh_presigned(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -267,8 +267,8 @@ def test_p2wpkh_in_p2sh_presigned(session: Session):
# Test corrupted script hash in scriptsig. # Test corrupted script hash in scriptsig.
inp1.script_sig[10] ^= 1 inp1.script_sig[10] ^= 1
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -339,7 +339,7 @@ def test_p2wpkh_presigned(session: Session):
) )
# Test with second input as pre-signed external. # Test with second input as pre-signed external.
with session: with session.client:
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, session,
"Testnet", "Testnet",
@ -399,8 +399,8 @@ def test_p2wsh_external_presigned(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -444,8 +444,8 @@ def test_p2wsh_external_presigned(session: Session):
# Test corrupted signature in witness. # Test corrupted signature in witness.
inp2.witness[10] ^= 1 inp2.witness[10] ^= 1
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -509,8 +509,8 @@ def test_p2tr_external_presigned(session: Session):
amount=4_600, amount=4_600,
script_type=messages.OutputScriptType.PAYTOTAPROOT, script_type=messages.OutputScriptType.PAYTOTAPROOT,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -541,8 +541,8 @@ def test_p2tr_external_presigned(session: Session):
# Test corrupted signature in witness. # Test corrupted signature in witness.
inp2.witness[10] ^= 1 inp2.witness[10] ^= 1
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -610,9 +610,9 @@ def test_p2wpkh_with_proof(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
is_t1 = session.model is models.T1B1 is_t1 = session.model is models.T1B1
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -703,9 +703,9 @@ def test_p2tr_with_proof(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
is_t1 = session.model is models.T1B1 is_t1 = session.model is models.T1B1
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -770,8 +770,8 @@ def test_p2wpkh_with_false_proof(session: Session):
script_type=messages.OutputScriptType.PAYTOWITNESS, script_type=messages.OutputScriptType.PAYTOWITNESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),

View File

@ -82,7 +82,7 @@ def test_invalid_path_prompt(session: Session):
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES) btc.sign_tx(session, "Litecoin", [inp1], [out1], prev_txes=PREV_TXES)
@ -108,7 +108,7 @@ def test_invalid_path_pass_forkid(session: Session):
with session.client as client: with session.client as client:
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES) btc.sign_tx(session, "Bcash", [inp1], [out1], prev_txes=PREV_TXES)
@ -178,8 +178,8 @@ def test_attack_path_segwit(session: Session):
return msg return msg
with session: with session.client as client:
session.set_filter(messages.TxAck, attack_processor) client.set_filter(messages.TxAck, attack_processor)
with pytest.raises(TrezorFailure): with pytest.raises(TrezorFailure):
btc.sign_tx( btc.sign_tx(
session, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx} session, "Testnet", [inp1, inp2], [out1], prev_txes={prev_hash: prev_tx}
@ -202,8 +202,8 @@ def test_invalid_path_fail_asap(session: Session):
script_type=messages.OutputScriptType.PAYTOWITNESS, script_type=messages.OutputScriptType.PAYTOWITNESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
messages.Failure(code=messages.FailureType.DataError), messages.Failure(code=messages.FailureType.DataError),

View File

@ -58,7 +58,7 @@ def test_non_segwit_segwit_inputs(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client:
signatures, serialized_tx = btc.sign_tx( signatures, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API
) )
@ -94,7 +94,7 @@ def test_segwit_non_segwit_inputs(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client:
signatures, serialized_tx = btc.sign_tx( signatures, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API session, "Testnet", [inp1, inp2], [out1], prev_txes=TX_API
) )
@ -138,7 +138,7 @@ def test_segwit_non_segwit_segwit_inputs(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client:
signatures, serialized_tx = btc.sign_tx( signatures, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API
) )
@ -180,7 +180,7 @@ def test_non_segwit_segwit_non_segwit_inputs(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client:
signatures, serialized_tx = btc.sign_tx( signatures, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API session, "Testnet", [inp1, inp2, inp3], [out1], prev_txes=TX_API
) )

View File

@ -204,7 +204,7 @@ def test_payment_request_details(session: Session):
] ]
with session.client as client: with session.client as client:
IF = InputFlowPaymentRequestDetails(client, outputs) IF = InputFlowPaymentRequestDetails(session.client, outputs)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(

View File

@ -130,8 +130,8 @@ def test_invalid_prev_hash_attack(session: Session, prev_hash):
msg.tx.inputs[0].prev_hash = prev_hash msg.tx.inputs[0].prev_hash = prev_hash
return msg return msg
with session, session.client as client, pytest.raises(TrezorFailure) as e: with session.client as client, pytest.raises(TrezorFailure) as e:
session.set_filter(messages.TxAck, attack_filter) client.set_filter(messages.TxAck, attack_filter)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
@ -168,9 +168,9 @@ def test_invalid_prev_hash_in_prevtx(session: Session, prev_hash):
tx_hash = hash_tx(serialize_tx(prev_tx)) tx_hash = hash_tx(serialize_tx(prev_tx))
inp0.prev_hash = tx_hash inp0.prev_hash = tx_hash
with session, session.client as client, pytest.raises(TrezorFailure) as e: with session.client as client, pytest.raises(TrezorFailure) as e:
if session.model is not models.T1B1: if session.model is not models.T1B1:
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx}) btc.sign_tx(session, "Bitcoin", [inp0], [out1], prev_txes={tx_hash: prev_tx})
_check_error_message(prev_hash, session.model, e.value.message) _check_error_message(prev_hash, session.model, e.value.message)

View File

@ -116,8 +116,8 @@ def test_p2pkh_fee_bump(session: Session):
orig_index=1, orig_index=1,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_meta(TXHASH_50f6f1), request_meta(TXHASH_50f6f1),
@ -190,7 +190,7 @@ def test_p2wpkh_op_return_fee_bump(session: Session):
orig_index=1, orig_index=1,
) )
with session: with session.client:
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, session,
"Testnet", "Testnet",
@ -243,8 +243,8 @@ def test_p2tr_fee_bump(session: Session):
orig_index=1, orig_index=1,
script_type=messages.OutputScriptType.PAYTOTAPROOT, script_type=messages.OutputScriptType.PAYTOTAPROOT,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_meta(TXHASH_8e4af7), request_meta(TXHASH_8e4af7),
@ -312,8 +312,8 @@ def test_p2wpkh_finalize(session: Session):
orig_index=1, orig_index=1,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_meta(TXHASH_70f987), request_meta(TXHASH_70f987),
@ -444,8 +444,8 @@ def test_p2wpkh_payjoin(
orig_index=1, orig_index=1,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_meta(TXHASH_65b768), request_meta(TXHASH_65b768),
@ -520,8 +520,8 @@ def test_p2wpkh_in_p2sh_remove_change(session: Session):
orig_index=0, orig_index=0,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_meta(TXHASH_334cd7), request_meta(TXHASH_334cd7),
@ -599,8 +599,8 @@ def test_p2wpkh_in_p2sh_fee_bump_from_external(session: Session):
orig_index=0, orig_index=0,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_meta(TXHASH_334cd7), request_meta(TXHASH_334cd7),
@ -720,8 +720,8 @@ def test_tx_meld(session: Session):
script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, script_type=messages.OutputScriptType.PAYTOP2SHWITNESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_meta(TXHASH_334cd7), request_meta(TXHASH_334cd7),

View File

@ -66,8 +66,8 @@ def test_send_p2sh(session: Session, chunkify: bool):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
amount=123_456_789 - 11_000 - 12_300_000, amount=123_456_789 - 11_000 - 12_300_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -124,8 +124,8 @@ def test_send_p2sh_change(session: Session):
script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, script_type=messages.OutputScriptType.PAYTOP2SHWITNESS,
amount=123_456_789 - 11_000 - 12_300_000, amount=123_456_789 - 11_000 - 12_300_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -179,8 +179,8 @@ def test_testnet_segwit_big_amount(session: Session):
amount=2**32 + 1, amount=2**32 + 1,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -254,8 +254,8 @@ def test_send_multisig_1(session: Session):
request_finished(), request_finished(),
] ]
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
signatures, _ = btc.sign_tx( signatures, _ = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -265,8 +265,8 @@ def test_send_multisig_1(session: Session):
# sign with third key # sign with third key
inp1.address_n[2] = H_(3) inp1.address_n[2] = H_(3)
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -282,6 +282,7 @@ def test_attack_change_input_address(session: Session):
# Simulates an attack where the user is coerced into unknowingly # Simulates an attack where the user is coerced into unknowingly
# transferring funds from one account to another one of their accounts, # transferring funds from one account to another one of their accounts,
# potentially resulting in privacy issues. # potentially resulting in privacy issues.
client = session.client
inp1 = messages.TxInputType( inp1 = messages.TxInputType(
address_n=parse_path("m/49h/1h/0h/1/0"), address_n=parse_path("m/49h/1h/0h/1/0"),
@ -303,8 +304,8 @@ def test_attack_change_input_address(session: Session):
) )
# Test if the transaction can be signed normally. # Test if the transaction can be signed normally.
with session: with client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -349,8 +350,8 @@ def test_attack_change_input_address(session: Session):
return msg return msg
# Now run the attack, must trigger the exception # Now run the attack, must trigger the exception
with session: with client:
session.set_filter(messages.TxAck, attack_processor) client.set_filter(messages.TxAck, attack_processor)
with pytest.raises(TrezorFailure): with pytest.raises(TrezorFailure):
btc.sign_tx( btc.sign_tx(
session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET
@ -360,6 +361,7 @@ def test_attack_change_input_address(session: Session):
def test_attack_mixed_inputs(session: Session): def test_attack_mixed_inputs(session: Session):
TRUE_AMOUNT = 123_456_789 TRUE_AMOUNT = 123_456_789
FAKE_AMOUNT = 120_000_000 FAKE_AMOUNT = 120_000_000
client = session.client
inp1 = messages.TxInputType( inp1 = messages.TxInputType(
address_n=parse_path("m/44h/1h/0h/0/0"), address_n=parse_path("m/44h/1h/0h/0/0"),
@ -421,10 +423,10 @@ def test_attack_mixed_inputs(session: Session):
# T1 asks for first input for witness again # T1 asks for first input for witness again
expected_responses.insert(-2, request_input(0)) expected_responses.insert(-2, request_input(0))
with session: with client:
# Sign unmodified transaction. # Sign unmodified transaction.
# "Fee over threshold" warning is displayed - fee is the whole TRUE_AMOUNT # "Fee over threshold" warning is displayed - fee is the whole TRUE_AMOUNT
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
btc.sign_tx( btc.sign_tx(
session, session,
"Testnet", "Testnet",
@ -446,8 +448,8 @@ def test_attack_mixed_inputs(session: Session):
expected_responses[:4] + expected_responses[5:16] + [messages.Failure()] expected_responses[:4] + expected_responses[5:16] + [messages.Failure()]
) )
with pytest.raises(TrezorFailure) as e, session: with pytest.raises(TrezorFailure) as e, client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
btc.sign_tx( btc.sign_tx(
session, session,
"Testnet", "Testnet",

View File

@ -82,8 +82,8 @@ def test_send_p2sh(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
amount=123_456_789 - 11_000 - 12_300_000, amount=123_456_789 - 11_000 - 12_300_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -137,8 +137,8 @@ def test_send_p2sh_change(session: Session):
script_type=messages.OutputScriptType.PAYTOP2SHWITNESS, script_type=messages.OutputScriptType.PAYTOP2SHWITNESS,
amount=123_456_789 - 11_000 - 12_300_000, amount=123_456_789 - 11_000 - 12_300_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -190,8 +190,8 @@ def test_send_native(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
amount=100_000 - 40_000 - 10_000, amount=100_000 - 40_000 - 10_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -244,7 +244,7 @@ def test_send_to_taproot(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
amount=10_000 - 7_000 - 200, amount=10_000 - 7_000 - 200,
) )
with session: with session.client:
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1, out2], prev_txes=TX_API_TESTNET
) )
@ -277,8 +277,8 @@ def test_send_native_change(session: Session):
script_type=messages.OutputScriptType.PAYTOWITNESS, script_type=messages.OutputScriptType.PAYTOWITNESS,
amount=100_000 - 40_000 - 10_000, amount=100_000 - 40_000 - 10_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -344,8 +344,8 @@ def test_send_both(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -449,8 +449,8 @@ def test_send_multisig_1(session: Session):
request_finished(), request_finished(),
] ]
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
signatures, _ = btc.sign_tx( signatures, _ = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -460,8 +460,8 @@ def test_send_multisig_1(session: Session):
# sign with third key # sign with third key
inp1.address_n[2] = H_(3) inp1.address_n[2] = H_(3)
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -526,8 +526,8 @@ def test_send_multisig_2(session: Session):
request_finished(), request_finished(),
] ]
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
signatures, _ = btc.sign_tx( signatures, _ = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -537,8 +537,8 @@ def test_send_multisig_2(session: Session):
# sign with first key # sign with first key
inp1.address_n[2] = H_(1) inp1.address_n[2] = H_(1)
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
) )
@ -611,10 +611,10 @@ def test_send_multisig_3_change(session: Session):
request_finished(), request_finished(),
] ]
with session, session.client as client: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
signatures, _ = btc.sign_tx( signatures, _ = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
@ -626,10 +626,10 @@ def test_send_multisig_3_change(session: Session):
inp1.address_n[2] = H_(3) inp1.address_n[2] = H_(3)
out1.address_n[2] = H_(3) out1.address_n[2] = H_(3)
with session, session.client as client: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
@ -703,10 +703,10 @@ def test_send_multisig_4_change(session: Session):
request_finished(), request_finished(),
] ]
with session, session.client as client: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
signatures, _ = btc.sign_tx( signatures, _ = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
@ -718,10 +718,10 @@ def test_send_multisig_4_change(session: Session):
inp1.address_n[2] = H_(3) inp1.address_n[2] = H_(3)
out1.address_n[2] = H_(3) out1.address_n[2] = H_(3)
with session, session.client as client: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
if is_core(session): if is_core(session):
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET session, "Testnet", [inp1], [out1], prev_txes=TX_API_TESTNET
@ -788,8 +788,8 @@ def test_multisig_mismatch_inputs_single(session: Session):
amount=100_000 + 100_000 - 50_000 - 10_000, amount=100_000 + 100_000 - 50_000 - 10_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),

View File

@ -79,8 +79,8 @@ def test_send_p2tr(session: Session, chunkify: bool):
amount=4_450, amount=4_450,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -133,8 +133,8 @@ def test_send_two_with_change(session: Session):
script_type=messages.OutputScriptType.PAYTOTAPROOT, script_type=messages.OutputScriptType.PAYTOTAPROOT,
amount=6_800 + 13_000 - 200 - 15_000, amount=6_800 + 13_000 - 200 - 15_000,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -222,8 +222,8 @@ def test_send_mixed(session: Session):
script_type=messages.OutputScriptType.PAYTOTAPROOT, script_type=messages.OutputScriptType.PAYTOTAPROOT,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
# process inputs # process inputs
request_input(0), request_input(0),
@ -353,9 +353,9 @@ def test_attack_script_type(session: Session):
return msg return msg
with session: with session.client as client:
session.set_filter(messages.TxAck, attack_processor) client.set_filter(messages.TxAck, attack_processor)
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -406,8 +406,8 @@ def test_send_invalid_address(session: Session, address: str):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session, pytest.raises(TrezorFailure): with session.client as client, pytest.raises(TrezorFailure):
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),

View File

@ -41,7 +41,7 @@ def test_message_long_legacy(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_message_long_core(session: Session): def test_message_long_core(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSignVerifyMessageLong(client, verify=True) IF = InputFlowSignVerifyMessageLong(session.client, verify=True)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
ret = btc.verify_message( ret = btc.verify_message(
session, session,

View File

@ -75,7 +75,7 @@ def test_v3_not_supported(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session, pytest.raises(TrezorFailure, match="DataError"): with session.client, pytest.raises(TrezorFailure, match="DataError"):
btc.sign_tx( btc.sign_tx(
session, session,
"Zcash Testnet", "Zcash Testnet",
@ -106,8 +106,8 @@ def test_one_one_fee_sapling(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -210,7 +210,7 @@ def test_spend_old_versions(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client:
_, serialized_tx = btc.sign_tx( _, serialized_tx = btc.sign_tx(
session, session,
"Zcash Testnet", "Zcash Testnet",
@ -259,8 +259,8 @@ def test_external_presigned(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),

View File

@ -95,9 +95,9 @@ def test_cardano_get_address(session: Session, chunkify: bool, parameters, resul
"cardano/get_public_key.derivations.json", "cardano/get_public_key.derivations.json",
) )
def test_cardano_get_public_key(session: Session, parameters, result): def test_cardano_get_public_key(session: Session, parameters, result):
with session: with session.client as client:
IF = InputFlowShowXpubQRCode(session.client, passphrase_request_expected=False) IF = InputFlowShowXpubQRCode(session.client, passphrase_request_expected=False)
session.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# session.init_device(new_session=True, derive_cardano=True) # session.init_device(new_session=True, derive_cardano=True)
derivation_type = CardanoDerivationType.__members__[ derivation_type = CardanoDerivationType.__members__[

View File

@ -63,7 +63,7 @@ def test_cardano_sign_tx(session: Session, parameters, result):
response = call_sign_tx( response = call_sign_tx(
session, session,
parameters, parameters,
input_flow=lambda client: InputFlowConfirmAllWarnings(client).get(), input_flow=lambda client: InputFlowConfirmAllWarnings(session.client).get(),
) )
assert response == _transform_expected_result(result) assert response == _transform_expected_result(result)
@ -124,8 +124,8 @@ def call_sign_tx(session: Session, parameters, input_flow=None, chunkify: bool =
with session.client as client: with session.client as client:
if input_flow is not None: if input_flow is not None:
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow(client)) client.set_input_flow(input_flow(session.client))
return cardano.sign_tx( return cardano.sign_tx(
session=session, session=session,

View File

@ -30,7 +30,7 @@ from ...input_flows import InputFlowShowXpubQRCode
@pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_eos_get_public_key(session: Session): def test_eos_get_public_key(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowShowXpubQRCode(client) IF = InputFlowShowXpubQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
public_key = get_public_key( public_key = get_public_key(
session, parse_path("m/44h/194h/0h/0/0"), show_display=True session, parse_path("m/44h/194h/0h/0/0"), show_display=True

View File

@ -60,7 +60,7 @@ def test_eos_signtx_transfer_token(session: Session, chunkify: bool):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID, chunkify=chunkify)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -93,7 +93,7 @@ def test_eos_signtx_buyram(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -126,7 +126,7 @@ def test_eos_signtx_buyrambytes(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -155,7 +155,7 @@ def test_eos_signtx_sellram(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -190,7 +190,7 @@ def test_eos_signtx_delegate(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -224,7 +224,7 @@ def test_eos_signtx_undelegate(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -253,7 +253,7 @@ def test_eos_signtx_refund(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -287,7 +287,7 @@ def test_eos_signtx_linkauth(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -320,7 +320,7 @@ def test_eos_signtx_unlinkauth(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -376,7 +376,7 @@ def test_eos_signtx_updateauth(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -405,7 +405,7 @@ def test_eos_signtx_deleteauth(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -468,7 +468,7 @@ def test_eos_signtx_vote(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -497,7 +497,7 @@ def test_eos_signtx_vote_proxy(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -526,7 +526,7 @@ def test_eos_signtx_unknown(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -602,7 +602,7 @@ def test_eos_signtx_newaccount(session: Session):
"transaction_extensions": [], "transaction_extensions": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (
@ -638,7 +638,7 @@ def test_eos_signtx_setcontract(session: Session):
"context_free_data": [], "context_free_data": [],
} }
with session: with session.client:
resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID) resp = eos.sign_tx(session, ADDRESS_N, transaction, CHAIN_ID)
assert isinstance(resp, EosSignedTx) assert isinstance(resp, EosSignedTx)
assert ( assert (

View File

@ -123,9 +123,9 @@ def test_external_token(session: Session) -> None:
def test_external_chain_without_token(session: Session) -> None: def test_external_chain_without_token(session: Session) -> None:
with session, session.client as client: with session.client as client:
if not client.debug.legacy_debug: if not session.client.debug.legacy_debug:
client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) client.set_input_flow(InputFlowConfirmAllWarnings(session.client).get())
# when using an external chains, unknown tokens are allowed # when using an external chains, unknown tokens are allowed
network = common.encode_network(chain_id=66666, slip44=60) network = common.encode_network(chain_id=66666, slip44=60)
params = DEFAULT_ERC20_PARAMS.copy() params = DEFAULT_ERC20_PARAMS.copy()
@ -145,9 +145,9 @@ def test_external_chain_token_ok(session: Session) -> None:
def test_external_chain_token_mismatch(session: Session) -> None: def test_external_chain_token_mismatch(session: Session) -> None:
with session, session.client as client: with session.client as client:
if not client.debug.legacy_debug: if not session.client.debug.legacy_debug:
client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) client.set_input_flow(InputFlowConfirmAllWarnings(session.client).get())
# when providing external defs, we explicitly allow, but not use, tokens # when providing external defs, we explicitly allow, but not use, tokens
# from other chains # from other chains
network = common.encode_network(chain_id=66666, slip44=60) network = common.encode_network(chain_id=66666, slip44=60)

View File

@ -38,7 +38,7 @@ def test_getaddress(session: Session, parameters, result):
@parametrize_using_common_fixtures("ethereum/getaddress.json") @parametrize_using_common_fixtures("ethereum/getaddress.json")
def test_getaddress_chunkify_details(session: Session, parameters, result): def test_getaddress_chunkify_details(session: Session, parameters, result):
with session.client as client: with session.client as client:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
address_n = parse_path(parameters["path"]) address_n = parse_path(parameters["path"])
assert ( assert (

View File

@ -29,7 +29,7 @@ pytestmark = [pytest.mark.altcoin, pytest.mark.ethereum]
@pytest.mark.models("core") @pytest.mark.models("core")
@parametrize_using_common_fixtures("ethereum/sign_typed_data.json") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json")
def test_ethereum_sign_typed_data(session: Session, parameters, result): def test_ethereum_sign_typed_data(session: Session, parameters, result):
with session: with session.client:
address_n = parse_path(parameters["path"]) address_n = parse_path(parameters["path"])
ret = ethereum.sign_typed_data( ret = ethereum.sign_typed_data(
session, session,
@ -44,7 +44,7 @@ def test_ethereum_sign_typed_data(session: Session, parameters, result):
@pytest.mark.models("legacy") @pytest.mark.models("legacy")
@parametrize_using_common_fixtures("ethereum/sign_typed_data.json") @parametrize_using_common_fixtures("ethereum/sign_typed_data.json")
def test_ethereum_sign_typed_data_blind(session: Session, parameters, result): def test_ethereum_sign_typed_data_blind(session: Session, parameters, result):
with session: with session.client:
address_n = parse_path(parameters["path"]) address_n = parse_path(parameters["path"])
ret = ethereum.sign_typed_data_hash( ret = ethereum.sign_typed_data_hash(
session, session,
@ -112,8 +112,8 @@ def test_ethereum_sign_typed_data_show_more_button(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_ethereum_sign_typed_data_cancel(session: Session): def test_ethereum_sign_typed_data_cancel(session: Session):
with session.client as client, pytest.raises(exceptions.Cancelled): with session.client as client, pytest.raises(exceptions.Cancelled):
client.watch_layout() session.client.watch_layout()
IF = InputFlowEIP712Cancel(client) IF = InputFlowEIP712Cancel(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
ethereum.sign_typed_data( ethereum.sign_typed_data(
session, session,

View File

@ -37,7 +37,7 @@ def test_signmessage(session: Session, parameters, result):
assert res.signature.hex() == result["sig"] assert res.signature.hex() == result["sig"]
else: else:
with session.client as client: with session.client as client:
IF = InputFlowSignVerifyMessageLong(client) IF = InputFlowSignVerifyMessageLong(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
res = ethereum.sign_message( res = ethereum.sign_message(
session, parse_path(parameters["path"]), parameters["msg"] session, parse_path(parameters["path"]), parameters["msg"]
@ -58,7 +58,7 @@ def test_verify(session: Session, parameters, result):
assert res is True assert res is True
else: else:
with session.client as client: with session.client as client:
IF = InputFlowSignVerifyMessageLong(client, verify=True) IF = InputFlowSignVerifyMessageLong(session.client, verify=True)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
res = ethereum.verify_message( res = ethereum.verify_message(
session, session,

View File

@ -147,9 +147,9 @@ def test_signtx_go_back_from_summary(session: Session):
def test_signtx_eip1559( def test_signtx_eip1559(
session: Session, chunkify: bool, parameters: dict, result: dict session: Session, chunkify: bool, parameters: dict, result: dict
): ):
with session, session.client as client: with session.client as client:
if not client.debug.legacy_debug: if not session.client.debug.legacy_debug:
client.set_input_flow(InputFlowConfirmAllWarnings(client).get()) client.set_input_flow(InputFlowConfirmAllWarnings(session.client).get())
sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559(
session, session,
n=parse_path(parameters["path"]), n=parse_path(parameters["path"]),
@ -218,8 +218,8 @@ def test_data_streaming(session: Session):
"""Only verifying the expected responses, the signatures are """Only verifying the expected responses, the signatures are
checked in vectorized function above. checked in vectorized function above.
""" """
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx),
messages.ButtonRequest(code=messages.ButtonRequestType.SignTx), messages.ButtonRequest(code=messages.ButtonRequestType.SignTx),
@ -266,7 +266,7 @@ def test_data_streaming(session: Session):
def test_signtx_eip1559_access_list(session: Session): def test_signtx_eip1559_access_list(session: Session):
with session: with session.client:
sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559(
session, session,
@ -305,7 +305,7 @@ def test_signtx_eip1559_access_list(session: Session):
def test_signtx_eip1559_access_list_larger(session: Session): def test_signtx_eip1559_access_list_larger(session: Session):
with session: with session.client:
sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559(
session, session,
@ -438,6 +438,8 @@ HEXDATA = "0123456789abcd000023456789abcd010003456789abcd020000456789abcd0300000
) )
@pytest.mark.models("core") @pytest.mark.models("core")
def test_signtx_data_pagination(session: Session, flow): def test_signtx_data_pagination(session: Session, flow):
client = session.client
def _sign_tx_call(): def _sign_tx_call():
ethereum.sign_tx( ethereum.sign_tx(
session, session,
@ -452,15 +454,15 @@ def test_signtx_data_pagination(session: Session, flow):
data=bytes.fromhex(HEXDATA), data=bytes.fromhex(HEXDATA),
) )
with session, session.client as client: with client:
client.watch_layout() client.watch_layout()
client.set_input_flow(flow(client)) client.set_input_flow(flow(session.client))
_sign_tx_call() _sign_tx_call()
if flow is not input_flow_data_scroll_down: if flow is not input_flow_data_scroll_down:
with session, session.client as client, pytest.raises(exceptions.Cancelled): with client, pytest.raises(exceptions.Cancelled):
client.watch_layout() client.watch_layout()
client.set_input_flow(flow(client, cancel=True)) client.set_input_flow(flow(session.client, cancel=True))
_sign_tx_call() _sign_tx_call()
@ -500,7 +502,7 @@ def test_signtx_staking_bad_inputs(session: Session, parameters: dict, result: d
@pytest.mark.models("core") @pytest.mark.models("core")
@parametrize_using_common_fixtures("ethereum/sign_tx_staking_eip1559.json") @parametrize_using_common_fixtures("ethereum/sign_tx_staking_eip1559.json")
def test_signtx_staking_eip1559(session: Session, parameters: dict, result: dict): def test_signtx_staking_eip1559(session: Session, parameters: dict, result: dict):
with session: with session.client:
sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559( sig_v, sig_r, sig_s = ethereum.sign_tx_eip1559(
session, session,
n=parse_path(parameters["path"]), n=parse_path(parameters["path"]),

View File

@ -33,7 +33,7 @@ def test_encrypt(client: Client):
client.debug.press_yes() client.debug.press_yes()
session = client.get_session() session = client.get_session()
with client, session: with session.client as client:
client.set_input_flow(input_flow()) client.set_input_flow(input_flow())
misc.encrypt_keyvalue( misc.encrypt_keyvalue(
session, session,

View File

@ -41,8 +41,8 @@ def entropy(data):
@pytest.mark.parametrize("entropy_length", ENTROPY_LENGTHS) @pytest.mark.parametrize("entropy_length", ENTROPY_LENGTHS)
def test_entropy(session: Session, entropy_length): def test_entropy(session: Session, entropy_length):
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[m.ButtonRequest(code=m.ButtonRequestType.ProtectCall), m.Entropy] [m.ButtonRequest(code=m.ButtonRequestType.ProtectCall), m.Entropy]
) )
ent = misc.get_entropy(session, entropy_length) ent = misc.get_entropy(session, entropy_length)

View File

@ -57,7 +57,7 @@ def test_monero_getaddress_chunkify_details(
session: Session, path: str, expected_address: bytes session: Session, path: str, expected_address: bytes
): ):
with session.client as client: with session.client as client:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
address = monero.get_address( address = monero.get_address(
session, parse_path(path), show_display=True, chunkify=True session, parse_path(path), show_display=True, chunkify=True

View File

@ -32,7 +32,7 @@ pytestmark = [
# assertion data from T1 # assertion data from T1
def test_nem_signtx_importance_transfer(session: Session): def test_nem_signtx_importance_transfer(session: Session):
with session: with session.client:
tx = nem.sign_tx( tx = nem.sign_tx(
session, session,
parse_path("m/44h/1h/0h/0h/0h"), parse_path("m/44h/1h/0h/0h/0h"),

View File

@ -33,8 +33,8 @@ pytestmark = [
# assertion data from T1 # assertion data from T1
@pytest.mark.parametrize("chunkify", (True, False)) @pytest.mark.parametrize("chunkify", (True, False))
def test_nem_signtx_simple(session: Session, chunkify: bool): def test_nem_signtx_simple(session: Session, chunkify: bool):
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
# Confirm transfer and network fee # Confirm transfer and network fee
messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput),
@ -83,8 +83,8 @@ def test_nem_signtx_simple(session: Session, chunkify: bool):
@pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_nem_signtx_encrypted_payload(session: Session): def test_nem_signtx_encrypted_payload(session: Session):
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
# Confirm transfer and network fee # Confirm transfer and network fee
messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput), messages.ButtonRequest(code=messages.ButtonRequestType.ConfirmOutput),

View File

@ -52,8 +52,8 @@ def do_recover_legacy(session: Session, mnemonic: list[str]):
def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False): def do_recover_core(session: Session, mnemonic: list[str], mismatch: bool = False):
with session.client as client: with session.client as client:
client.watch_layout() session.client.watch_layout()
IF = InputFlowBip39RecoveryDryRun(client, mnemonic, mismatch=mismatch) IF = InputFlowBip39RecoveryDryRun(session.client, mnemonic, mismatch=mismatch)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
return device.recover(session, type=messages.RecoveryType.DryRun) return device.recover(session, type=messages.RecoveryType.DryRun)
@ -87,8 +87,8 @@ def test_invalid_seed_t1(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_invalid_seed_core(session: Session): def test_invalid_seed_core(session: Session):
with session, session.client as client: with session.client as client:
client.watch_layout() session.client.watch_layout()
IF = InputFlowBip39RecoveryDryRunInvalid(session) IF = InputFlowBip39RecoveryDryRunInvalid(session)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):

View File

@ -29,7 +29,7 @@ pytestmark = pytest.mark.models("core")
@pytest.mark.uninitialized_session @pytest.mark.uninitialized_session
def test_tt_pin_passphrase(session: Session): def test_tt_pin_passphrase(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" "), pin="654") IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" "), pin="654")
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
@ -50,7 +50,7 @@ def test_tt_pin_passphrase(session: Session):
@pytest.mark.uninitialized_session @pytest.mark.uninitialized_session
def test_tt_nopin_nopassphrase(session: Session): def test_tt_nopin_nopassphrase(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowBip39Recovery(client, MNEMONIC12.split(" ")) IF = InputFlowBip39Recovery(session.client, MNEMONIC12.split(" "))
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover( device.recover(
session, session,

View File

@ -49,7 +49,9 @@ def _test_secret(
session: Session, shares: list[str], secret: str, click_info: bool = False session: Session, shares: list[str], secret: str, click_info: bool = False
): ):
with session.client as client: with session.client as client:
IF = InputFlowSlip39AdvancedRecovery(client, shares, click_info=click_info) IF = InputFlowSlip39AdvancedRecovery(
session.client, shares, click_info=click_info
)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
@ -90,7 +92,7 @@ def test_extra_share_entered(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_abort(session: Session): def test_abort(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSlip39AdvancedRecoveryAbort(client) IF = InputFlowSlip39AdvancedRecoveryAbort(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
@ -102,7 +104,7 @@ def test_abort(session: Session):
def test_noabort(session: Session): def test_noabort(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSlip39AdvancedRecoveryNoAbort( IF = InputFlowSlip39AdvancedRecoveryNoAbort(
client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
) )
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
@ -118,7 +120,7 @@ def test_same_share(session: Session):
# second share is first 4 words of first # second share is first 4 words of first
second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4] second_share = MNEMONIC_SLIP39_ADVANCED_20[1].split(" ")[:4]
with session, session.client as client: with session.client as client:
IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered( IF = InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(
session, first_share, second_share session, first_share, second_share
) )
@ -134,7 +136,7 @@ def test_group_threshold_reached(session: Session):
# second share is first 3 words of first # second share is first 3 words of first
second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3] second_share = MNEMONIC_SLIP39_ADVANCED_20[0].split(" ")[:3]
with session, session.client as client: with session.client as client:
IF = InputFlowSlip39AdvancedRecoveryThresholdReached( IF = InputFlowSlip39AdvancedRecoveryThresholdReached(
session, first_share, second_share session, first_share, second_share
) )

View File

@ -42,7 +42,7 @@ EXTRA_GROUP_SHARE = [
def test_2of3_dryrun(session: Session): def test_2of3_dryrun(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSlip39AdvancedRecoveryDryRun( IF = InputFlowSlip39AdvancedRecoveryDryRun(
client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20 session.client, EXTRA_GROUP_SHARE + MNEMONIC_SLIP39_ADVANCED_20
) )
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover( device.recover(
@ -61,7 +61,7 @@ def test_2of3_invalid_seed_dryrun(session: Session):
TrezorFailure, match=r"The seed does not match the one in the device" TrezorFailure, match=r"The seed does not match the one in the device"
): ):
IF = InputFlowSlip39AdvancedRecoveryDryRun( IF = InputFlowSlip39AdvancedRecoveryDryRun(
client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True session.client, INVALID_SHARES_SLIP39_ADVANCED_20, mismatch=True
) )
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover( device.recover(

View File

@ -74,7 +74,7 @@ def test_secret(
session: Session, shares: list[str], secret: str, backup_type: messages.BackupType session: Session, shares: list[str], secret: str, backup_type: messages.BackupType
): ):
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecovery(client, shares) IF = InputFlowSlip39BasicRecovery(session.client, shares)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
@ -91,7 +91,7 @@ def test_secret(
def test_recover_with_pin_passphrase(session: Session): def test_recover_with_pin_passphrase(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecovery( IF = InputFlowSlip39BasicRecovery(
client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654" session.client, MNEMONIC_SLIP39_BASIC_20_3of6, pin="654"
) )
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover( device.recover(
@ -110,7 +110,7 @@ def test_recover_with_pin_passphrase(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_abort(session: Session): def test_abort(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecoveryAbort(client) IF = InputFlowSlip39BasicRecoveryAbort(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
@ -124,7 +124,7 @@ def test_abort(session: Session):
def test_abort_on_number_of_words(session: Session): def test_abort_on_number_of_words(session: Session):
# on Caesar, test_abort actually aborts on the # of words selection # on Caesar, test_abort actually aborts on the # of words selection
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(client) IF = InputFlowSlip39BasicRecoveryAbortOnNumberOfWords(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
@ -136,7 +136,7 @@ def test_abort_on_number_of_words(session: Session):
def test_abort_between_shares(session: Session): def test_abort_between_shares(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecoveryAbortBetweenShares( IF = InputFlowSlip39BasicRecoveryAbortBetweenShares(
client, MNEMONIC_SLIP39_BASIC_20_3of6 session.client, MNEMONIC_SLIP39_BASIC_20_3of6
) )
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
@ -149,7 +149,9 @@ def test_abort_between_shares(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_noabort(session: Session): def test_noabort(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecoveryNoAbort(client, MNEMONIC_SLIP39_BASIC_20_3of6) IF = InputFlowSlip39BasicRecoveryNoAbort(
session.client, MNEMONIC_SLIP39_BASIC_20_3of6
)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
session.refresh_features() session.refresh_features()
@ -158,7 +160,7 @@ def test_noabort(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_invalid_mnemonic_first_share(session: Session): def test_invalid_mnemonic_first_share(session: Session):
with session, session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session) IF = InputFlowSlip39BasicRecoveryInvalidFirstShare(session)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
@ -169,7 +171,7 @@ def test_invalid_mnemonic_first_share(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_invalid_mnemonic_second_share(session: Session): def test_invalid_mnemonic_second_share(session: Session):
with session, session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecoveryInvalidSecondShare( IF = InputFlowSlip39BasicRecoveryInvalidSecondShare(
session, MNEMONIC_SLIP39_BASIC_20_3of6 session, MNEMONIC_SLIP39_BASIC_20_3of6
) )
@ -184,7 +186,7 @@ def test_invalid_mnemonic_second_share(session: Session):
@pytest.mark.parametrize("nth_word", range(3)) @pytest.mark.parametrize("nth_word", range(3))
def test_wrong_nth_word(session: Session, nth_word: int): def test_wrong_nth_word(session: Session, nth_word: int):
share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
with session, session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word) IF = InputFlowSlip39BasicRecoveryWrongNthWord(session, share, nth_word)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
@ -194,7 +196,7 @@ def test_wrong_nth_word(session: Session, nth_word: int):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_same_share(session: Session): def test_same_share(session: Session):
share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ") share = MNEMONIC_SLIP39_BASIC_20_3of6[0].split(" ")
with session, session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecoverySameShare(session, share) IF = InputFlowSlip39BasicRecoverySameShare(session, share)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
@ -204,7 +206,7 @@ def test_same_share(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
def test_1of1(session: Session): def test_1of1(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecovery(client, MNEMONIC_SLIP39_BASIC_20_1of1) IF = InputFlowSlip39BasicRecovery(session.client, MNEMONIC_SLIP39_BASIC_20_1of1)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover( device.recover(
session, session,

View File

@ -39,7 +39,7 @@ INVALID_SHARES_20_2of3 = [
@pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2]) @pytest.mark.setup_client(mnemonic=SHARES_20_2of3[0:2])
def test_2of3_dryrun(session: Session): def test_2of3_dryrun(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecoveryDryRun(client, SHARES_20_2of3[1:3]) IF = InputFlowSlip39BasicRecoveryDryRun(session.client, SHARES_20_2of3[1:3])
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover( device.recover(
session, session,
@ -57,7 +57,7 @@ def test_2of3_invalid_seed_dryrun(session: Session):
TrezorFailure, match=r"The seed does not match the one in the device" TrezorFailure, match=r"The seed does not match the one in the device"
): ):
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, INVALID_SHARES_20_2of3, mismatch=True session.client, INVALID_SHARES_20_2of3, mismatch=True
) )
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover( device.recover(

View File

@ -78,7 +78,7 @@ VECTORS = [
def test_skip_backup_msg(session: Session, backup_type, backup_flow): def test_skip_backup_msg(session: Session, backup_type, backup_flow):
assert session.features.initialized is False assert session.features.initialized is False
with session: with session.client:
device.setup( device.setup(
session, session,
skip_backup=True, skip_backup=True,
@ -116,7 +116,7 @@ def test_skip_backup_msg(session: Session, backup_type, backup_flow):
def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow): def test_skip_backup_manual(session: Session, backup_type: BackupType, backup_flow):
assert session.features.initialized is False assert session.features.initialized is False
with session, session.client as client: with session.client as client:
IF = InputFlowResetSkipBackup(client) IF = InputFlowResetSkipBackup(client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.setup( device.setup(

View File

@ -36,7 +36,7 @@ pytestmark = pytest.mark.models("core")
def reset_device(session: Session, strength: int): def reset_device(session: Session, strength: int):
debug = session.client.debug debug = session.client.debug
with session.client as client: with session.client as client:
IF = InputFlowBip39ResetBackup(client) IF = InputFlowBip39ResetBackup(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
@ -92,7 +92,7 @@ def test_reset_device_pin(session: Session):
strength = 256 # 24 words strength = 256 # 24 words
with session.client as client: with session.client as client:
IF = InputFlowBip39ResetPIN(client) IF = InputFlowBip39ResetPIN(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# PIN, passphrase, display random # PIN, passphrase, display random
@ -130,7 +130,7 @@ def test_reset_entropy_check(session: Session):
strength = 128 # 12 words strength = 128 # 12 words
with session.client as client: with session.client as client:
IF = InputFlowBip39ResetBackup(client) IF = InputFlowBip39ResetBackup(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# No PIN, no passphrase # No PIN, no passphrase
@ -146,7 +146,7 @@ def test_reset_entropy_check(session: Session):
) )
# Generate the mnemonic locally. # Generate the mnemonic locally.
internal_entropy = client.debug.state().reset_entropy internal_entropy = session.client.debug.state().reset_entropy
assert internal_entropy is not None assert internal_entropy is not None
entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) entropy = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY)
expected_mnemonic = Mnemonic("english").to_mnemonic(entropy) expected_mnemonic = Mnemonic("english").to_mnemonic(entropy)
@ -177,7 +177,7 @@ def test_reset_failed_check(session: Session):
strength = 256 # 24 words strength = 256 # 24 words
with session.client as client: with session.client as client:
IF = InputFlowBip39ResetFailedCheck(client) IF = InputFlowBip39ResetFailedCheck(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# PIN, passphrase, display random # PIN, passphrase, display random
@ -263,9 +263,9 @@ def test_already_initialized(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
@pytest.mark.uninitialized_session @pytest.mark.uninitialized_session
def test_entropy_check(session: Session): def test_entropy_check(session: Session):
with session: with session.client as client:
delizia = session.client.debug.layout_type is LayoutType.Delizia delizia = session.client.debug.layout_type is LayoutType.Delizia
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(name="setup_device"), messages.ButtonRequest(name="setup_device"),
(delizia, messages.ButtonRequest(name="confirm_setup_device")), (delizia, messages.ButtonRequest(name="confirm_setup_device")),
@ -300,9 +300,9 @@ def test_entropy_check(session: Session):
@pytest.mark.setup_client(uninitialized=True) @pytest.mark.setup_client(uninitialized=True)
@pytest.mark.uninitialized_session @pytest.mark.uninitialized_session
def test_no_entropy_check(session: Session): def test_no_entropy_check(session: Session):
with session: with session.client as client:
delizia = session.client.debug.layout_type is LayoutType.Delizia delizia = session.client.debug.layout_type is LayoutType.Delizia
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(name="setup_device"), messages.ButtonRequest(name="setup_device"),
(delizia, messages.ButtonRequest(name="confirm_setup_device")), (delizia, messages.ButtonRequest(name="confirm_setup_device")),

View File

@ -48,7 +48,7 @@ def test_reset_recovery(client: Client):
def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str: def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> str:
with session.client as client: with session.client as client:
IF = InputFlowBip39ResetBackup(client) IF = InputFlowBip39ResetBackup(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
@ -78,9 +78,9 @@ def reset(session: Session, strength: int = 128, skip_backup: bool = False) -> s
def recover(session: Session, mnemonic: str): def recover(session: Session, mnemonic: str):
words = mnemonic.split(" ") words = mnemonic.split(" ")
with session.client as client: with session.client as client:
IF = InputFlowBip39Recovery(client, words) IF = InputFlowBip39Recovery(session.client, words)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
client.watch_layout() session.client.watch_layout()
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")
# Workflow successfully ended # Workflow successfully ended

View File

@ -69,7 +69,7 @@ def test_reset_recovery(client: Client):
def reset(session: Session, strength: int = 128) -> list[str]: def reset(session: Session, strength: int = 128) -> list[str]:
with session.client as client: with session.client as client:
IF = InputFlowSlip39AdvancedResetRecovery(client, False) IF = InputFlowSlip39AdvancedResetRecovery(session.client, False)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random

View File

@ -59,7 +59,7 @@ def test_reset_recovery(client: Client):
def reset(session: Session, strength: int = 128) -> list[str]: def reset(session: Session, strength: int = 128) -> list[str]:
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicResetRecovery(client) IF = InputFlowSlip39BasicResetRecovery(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
@ -88,7 +88,7 @@ def reset(session: Session, strength: int = 128) -> list[str]:
def recover(session: Session, shares: t.Sequence[str]): def recover(session: Session, shares: t.Sequence[str]):
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicRecovery(client, shares) IF = InputFlowSlip39BasicRecovery(session.client, shares)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.recover(session, pin_protection=False, label="label") device.recover(session, pin_protection=False, label="label")

View File

@ -34,10 +34,10 @@ def test_reset_device_slip39_advanced(client: Client):
strength = 128 strength = 128
member_threshold = 3 member_threshold = 3
with client: session = client.get_seedless_session()
with session.client as client:
IF = InputFlowSlip39AdvancedResetRecovery(client, False) IF = InputFlowSlip39AdvancedResetRecovery(client, False)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
session = client.get_seedless_session()
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
device.setup( device.setup(
session, session,

View File

@ -35,7 +35,7 @@ def reset_device(session: Session, strength: int):
member_threshold = 3 member_threshold = 3
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicResetRecovery(client) IF = InputFlowSlip39BasicResetRecovery(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# No PIN, no passphrase, don't display random # No PIN, no passphrase, don't display random
@ -90,7 +90,7 @@ def test_reset_entropy_check(session: Session):
strength = 128 # 20 words strength = 128 # 20 words
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicResetRecovery(client) IF = InputFlowSlip39BasicResetRecovery(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# No PIN, no passphrase. # No PIN, no passphrase.

View File

@ -53,7 +53,7 @@ def test_ripple_get_address_chunkify_details(
session: Session, path: str, expected_address: str session: Session, path: str, expected_address: str
): ):
with session.client as client: with session.client as client:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
address = get_address( address = get_address(
session, parse_path(path), show_display=True, chunkify=True session, parse_path(path), show_display=True, chunkify=True

View File

@ -48,7 +48,7 @@ def test_solana_sign_tx(session: Session, parameters, result):
serialized_tx = _serialize_tx(parameters["construct"]) serialized_tx = _serialize_tx(parameters["construct"])
with session.client as client: with session.client as client:
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
actual_result = sign_tx( actual_result = sign_tx(
session, session,

View File

@ -123,7 +123,7 @@ def test_get_address(session: Session, parameters, result):
@parametrize_using_common_fixtures("stellar/get_address.json") @parametrize_using_common_fixtures("stellar/get_address.json")
def test_get_address_chunkify_details(session: Session, parameters, result): def test_get_address_chunkify_details(session: Session, parameters, result):
with session.client as client: with session.client as client:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
address_n = parse_path(parameters["path"]) address_n = parse_path(parameters["path"])
address = stellar.get_address( address = stellar.get_address(

View File

@ -38,9 +38,9 @@ def pin_request(session: Session):
def set_autolock_delay(session: Session, delay): def set_autolock_delay(session: Session, delay):
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
session.set_expected_responses( client.set_expected_responses(
[ [
pin_request(session), pin_request(session),
messages.ButtonRequest, messages.ButtonRequest,
@ -52,18 +52,19 @@ def set_autolock_delay(session: Session, delay):
def test_apply_auto_lock_delay(session: Session): def test_apply_auto_lock_delay(session: Session):
client = session.client
set_autolock_delay(session, 10 * 1000) set_autolock_delay(session, 10 * 1000)
time.sleep(0.1) # sleep less than auto-lock delay time.sleep(0.1) # sleep less than auto-lock delay
with session: with client:
# No PIN protection is required. # No PIN protection is required.
session.set_expected_responses([messages.Address]) client.set_expected_responses([messages.Address])
get_test_address(session) get_test_address(session)
time.sleep(10.5) # sleep more than auto-lock delay time.sleep(10.5) # sleep more than auto-lock delay
with session, session.client as client: with client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
session.set_expected_responses([pin_request(session), messages.Address]) client.set_expected_responses([pin_request(session), messages.Address])
get_test_address(session) get_test_address(session)
@ -85,7 +86,7 @@ def test_apply_auto_lock_delay_valid(session: Session, seconds):
def test_autolock_default_value(session: Session): def test_autolock_default_value(session: Session):
assert session.features.auto_lock_delay_ms is None assert session.features.auto_lock_delay_ms is None
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
device.apply_settings(session, label="pls unlock") device.apply_settings(session, label="pls unlock")
session.refresh_features() session.refresh_features()
@ -98,9 +99,9 @@ def test_autolock_default_value(session: Session):
) )
def test_apply_auto_lock_delay_out_of_range(session: Session, seconds): def test_apply_auto_lock_delay_out_of_range(session: Session, seconds):
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses( client.set_expected_responses(
[ [
pin_request(session), pin_request(session),
messages.Failure(code=messages.FailureType.ProcessError), messages.Failure(code=messages.FailureType.ProcessError),

View File

@ -40,8 +40,8 @@ def test_cancel_message_via_cancel(session: Session, message):
yield yield
session.cancel() session.cancel()
with session, session.client as client, pytest.raises(Cancelled): with session.client as client, pytest.raises(Cancelled):
session.set_expected_responses([m.ButtonRequest(), m.Failure()]) client.set_expected_responses([m.ButtonRequest(), m.Failure()])
client.set_input_flow(input_flow) client.set_input_flow(input_flow)
session.call(message) session.call(message)

View File

@ -79,7 +79,7 @@ def _check_ping_screen_texts(session: Session, title: str, right_button: str) ->
if session.model in (models.T2T1, models.T3T1): if session.model in (models.T2T1, models.T3T1):
right_button = "-" right_button = "-"
with session, session.client as client: with session.client as client:
client.watch_layout(True) client.watch_layout(True)
client.set_input_flow(ping_input_flow(session, title, right_button)) client.set_input_flow(ping_input_flow(session, title, right_button))
ping = session.call(messages.Ping(message="ahoj!", button_protection=True)) ping = session.call(messages.Ping(message="ahoj!", button_protection=True))
@ -93,7 +93,7 @@ def test_error_too_long(session: Session):
max_length = MAX_DATA_LENGTH[session.model] max_length = MAX_DATA_LENGTH[session.model]
with pytest.raises( with pytest.raises(
exceptions.TrezorFailure, match="Translations too long" exceptions.TrezorFailure, match="Translations too long"
), session: ), session.client:
bad_data = (max_length + 1) * b"a" bad_data = (max_length + 1) * b"a"
device.change_language(session, language_data=bad_data) device.change_language(session, language_data=bad_data)
assert session.features.language == "en-US" assert session.features.language == "en-US"
@ -104,7 +104,9 @@ def test_error_invalid_data_length(session: Session):
assert session.features.language == "en-US" assert session.features.language == "en-US"
# Invalid data length # Invalid data length
# Sending more data than advertised in the header # Sending more data than advertised in the header
with pytest.raises(exceptions.TrezorFailure, match="Invalid data length"), session: with pytest.raises(
exceptions.TrezorFailure, match="Invalid data length"
), session.client:
good_data = build_and_sign_blob("cs", session) good_data = build_and_sign_blob("cs", session)
bad_data = good_data + b"abcd" bad_data = good_data + b"abcd"
device.change_language(session, language_data=bad_data) device.change_language(session, language_data=bad_data)
@ -118,7 +120,7 @@ def test_error_invalid_header_magic(session: Session):
# Does not match the expected magic # Does not match the expected magic
with pytest.raises( with pytest.raises(
exceptions.TrezorFailure, match="Invalid translations data" exceptions.TrezorFailure, match="Invalid translations data"
), session: ), session.client:
good_data = build_and_sign_blob("cs", session) good_data = build_and_sign_blob("cs", session)
bad_data = 4 * b"a" + good_data[4:] bad_data = 4 * b"a" + good_data[4:]
device.change_language(session, language_data=bad_data) device.change_language(session, language_data=bad_data)
@ -132,7 +134,7 @@ def test_error_invalid_data_hash(session: Session):
# Changing the data after their hash has been calculated # Changing the data after their hash has been calculated
with pytest.raises( with pytest.raises(
exceptions.TrezorFailure, match="Translation data verification failed" exceptions.TrezorFailure, match="Translation data verification failed"
), session: ), session.client:
good_data = build_and_sign_blob("cs", session) good_data = build_and_sign_blob("cs", session)
bad_data = good_data[:-8] + 8 * b"a" bad_data = good_data[:-8] + 8 * b"a"
device.change_language( device.change_language(
@ -149,7 +151,7 @@ def test_error_version_mismatch(session: Session):
# Change the version to one not matching the current device # Change the version to one not matching the current device
with pytest.raises( with pytest.raises(
exceptions.TrezorFailure, match="Translations version mismatch" exceptions.TrezorFailure, match="Translations version mismatch"
), session: ), session.client:
blob = prepare_blob("cs", session.model, (3, 5, 4, 0)) blob = prepare_blob("cs", session.model, (3, 5, 4, 0))
device.change_language( device.change_language(
session, session,
@ -165,7 +167,7 @@ def test_error_invalid_signature(session: Session):
# Changing the data in the signature section # Changing the data in the signature section
with pytest.raises( with pytest.raises(
exceptions.TrezorFailure, match="Invalid translations data" exceptions.TrezorFailure, match="Invalid translations data"
), session: ), session.client:
blob = prepare_blob("cs", session.model, session.version) blob = prepare_blob("cs", session.model, session.version)
blob.proof = translations.Proof( blob.proof = translations.Proof(
merkle_proof=[], merkle_proof=[],
@ -274,7 +276,7 @@ def test_reject_update(session: Session):
yield yield
session.client.debug.press_no() session.client.debug.press_no()
with pytest.raises(exceptions.Cancelled), session, session.client as client: with pytest.raises(exceptions.Cancelled), session.client as client:
client.set_input_flow(input_flow_reject) client.set_input_flow(input_flow_reject)
device.change_language(session, language_data) device.change_language(session, language_data)
@ -311,8 +313,8 @@ def _maybe_confirm_set_language(
else: else:
expected_responses = expected_responses_silent expected_responses = expected_responses_silent
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
device.change_language(session, language_data, show_display=show_display) device.change_language(session, language_data, show_display=show_display)
assert session.features.language is not None assert session.features.language is not None
assert session.features.language[:2] == lang assert session.features.language[:2] == lang
@ -320,9 +322,9 @@ def _maybe_confirm_set_language(
# explicitly handle the cases when expected_responses are correct for # explicitly handle the cases when expected_responses are correct for
# change_language but incorrect for selected is_displayed mode (otherwise the # change_language but incorrect for selected is_displayed mode (otherwise the
# user would get an unhelpful generic expected_responses mismatch) # user would get an unhelpful generic expected_responses mismatch)
if is_displayed and session.actual_responses == expected_responses_silent: if is_displayed and client.actual_responses == expected_responses_silent:
raise AssertionError("Change should have been visible but was silent") raise AssertionError("Change should have been visible but was silent")
if not is_displayed and session.actual_responses == expected_responses_confirm: if not is_displayed and client.actual_responses == expected_responses_confirm:
raise AssertionError("Change should have been silent but was visible") raise AssertionError("Change should have been silent but was visible")
# if the expected_responses do not match either, the generic error message will # if the expected_responses do not match either, the generic error message will
# be raised by the session context manager # be raised by the session context manager

View File

@ -20,6 +20,7 @@ import pytest
from trezorlib import btc, device, exceptions, messages, misc, models from trezorlib import btc, device, exceptions, messages, misc, models
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ..input_flows import InputFlowConfirmAllWarnings from ..input_flows import InputFlowConfirmAllWarnings
@ -50,19 +51,19 @@ T1_HOMESCREEN = b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x
TR_HOMESCREEN = b"TOIG\x80\x00@\x00\x0c\x04\x00\x00\xa5RY\x96\xdc0\x08\xe4\x06\xdc\xff\x96\xdc\x80\xa8\x16\x90z\xd2y\xf9\x18{<m\xc9\x02j\xeb\xea_\xddy\xfe~\xf14:\xfc\xe2I\x00\xf8\xff\x13\xff\xfa\xc6>\xc0\xf1\xe5\xc9y\x0f\x95\x7f;C\xfe\xd0\xe1K\xefS\x96o\xf9\xb739\x1a\n\xc7\xde\x89\xff\x11\xd8=\xd5\xcf\xb1\x9f\xf7U\xf2\xa3spx\xb0&t\xe4\xaf3x\xcaT\xec\xe50k\xb4\xe8\nl\x16\xbf`'\xf3\xa7Z\x8d-\x98h\x1c\x03\x07\xf0\xcf\xf0\x8aD\x13\xec\x1f@y\x9e\xd8\xa3\xc6\x84F*\x1dx\x02U\x00\x10\xd3\x8cF\xbb\x97y\x18J\xa5T\x18x\x1c\x02\xc6\x90\xfd\xdc\x89\x1a\x94\xb3\xeb\x01\xdc\x9f2\x8c/\xe9/\x8c$\xc6\x9c\x1e\xf8C\x8f@\x17Q\x1d\x11F\x02g\xe4A \xebO\xad\xc6\xe3F\xa7\x8b\xf830R\x82\x0b\x8e\x16\x1dL,\x14\xce\x057tht^\xfe\x00\x9e\x86\xc2\x86\xa3b~^Bl\x18\x1f\xb9+w\x11\x14\xceO\xe9\xb6W\xd8\x85\xbeX\x17\xc2\x13,M`y\xd1~\xa3/\xcd0\xed6\xda\xf5b\x15\xb5\x18\x0f_\xf6\xe2\xdc\x8d\x8ez\xdd\xd5\r^O\x9e\xb6|\xc4e\x0f\x1f\xff0k\xd4\xb8\n\x12{\x8d\x8a>\x0b5\xa2o\xf2jZ\xe5\xee\xdc\x14\xd1\xbd\xd5\xad\x95\xbe\x8c\t\x8f\xb9\xde\xc4\xa551,#`\x94'\x1b\xe7\xd53u\x8fq\xbd4v>3\x8f\xcc\x1d\xbcV>\x90^\xb3L\xc3\xde0]\x05\xec\x83\xd0\x07\xd2(\xbb\xcf+\xd0\xc7ru\xecn\x14k-\xc0|\xd2\x0e\xe8\xe08\xa8<\xdaQ+{\xad\x01\x02#\x16\x12+\xc8\xe0P\x06\xedD7\xae\xd0\xa4\x97\x84\xe32\xca;]\xd04x:\x94`\xbe\xca\x89\xe2\xcb\xc5L\x03\xac|\xe7\xd5\x1f\xe3\x08_\xee!\x04\xd2\xef\x00\xd8\xea\x91p)\xed^#\xb1\xa78eJ\x00F*\xc7\xf1\x0c\x1a\x04\xf5l\xcc\xfc\xa4\x83,c\x1e\xb1>\xc5q\x8b\xe6Y9\xc7\x07\xfa\xcf\xf9\x15\x8a\xdd\x11\x1f\x98\x82\xbe>\xbe+u#g]aC\\\x1bC\xb1\xe8P\xce2\xd6\xb6r\x12\x1c*\xd3\x92\x9d9\xf9cB\x82\xf9S.\xc2B\xe7\x9d\xcf\xdb\xf3\xfd#\xfd\x94x9p<D?\x0e0\xd0)ufMK\x9d\x84\xbf\x95\x02\x15\x04\xaf\x9b\xd7|\x9f\xf5\xc2\x19D\xe1\xe8pC=\\\xb54\xff\xfd<\xfc\x8b\x83\x19\x9aZ\x99J\x9d\xa2oP6\xb2=\xe0\xe5=Z0\x7f\xb6\xe9\xb1\x98\n\xcc \xdb\x9f\xb6\xf4\xc2\x82Z:\t\xf2\xcd\x88\xe3\x8a0\n\x13\xdd\xf9;\xdbtr\xf4kj\xa6\x90\x9d\x97\x92+#\xf4;t\x1e6\x80\xcd)\xfe\xe1\xabdD(V\xf5\xcc\xcf\xbeY\xd8\xddX\x08*\xc5\xd6\xa1\xa2\xae\x89s]\xafj\x9b\x07\x8d\xcf\xb5\n\x162\xb7\xb0\x02p\xc0.{\xbf\xd6\xc7|\x8a\x8c\xc9g\xa8DO\x85\xf6<E\x05Ek\x8c\xbfU\x13bz\xcf\xd0\x07\xcd^\x0f\x9b\x951\xa1vb\x17u:\xd2X\x91/\x0f\x9a\xae\x16T\x81\xb6\x8e\xdc,\xb0\xa1=\x11af%^\xec\x95\x83\xa9\xb8+\xd0i\xe0+#%\x02\xfd2\x84\xf3\xde\x0c\x88\x8c\x80\xf7\xc2H\xaa#\x80m\xf4\x1e\xd4#\x04J\r\xb6\xf83s\x8c\x0e\x0bx\xabw\xbe\x90\x94\x90:C\x88\x9bR`B\xc02\x1a\x08\xca-M9\xac\xa3TP\xb1\x83\xf2\x8aT\xe9\xc0c9(\xe5d\xd1\xac\xfd\x83\xf3\xb4C\x95\x04doh\xd7M\xed \xd0\x90\xc9|\x8a\x1fX\x1f\x0eI\x12\x8e&\xc3\x91NM-\x02\xeckp\x1a/\x19\x9d\xf2\xb4$\x0eG:\xbe\xb2~\x10(,\x1cd\x07\xbb]n^F@\x173'\xc63\xdf!u\xf4F\xa9\xc3\x96E!e\xc2Iv\xe8zQH=v\x89\x9a\x04a^\x10\x06\x01\x04!2\xa6\x1b\xba\x111/\xfa\xfa\x9c\x15 @\xf6\xb9{&\xcc\x84\xfa\xd6\x81\x90\xd4\x92\xcf\x89/\xc8\x80\xadP\xa3\xf4Xa\x1f\x04A\x07\xe6N\xd2oEZ\xc9\xa6(!\x8e#|\x0e\xfbq\xce\xe6\x8b-;\x06_\x04n\xdc\x8d^\x05s\xd2\xa8\x0f\xfa/\xfa\xf8\xe1x\n\xa3\xf701i7\x0c \x87\xec#\x80\x9c^X\x02\x01C\xc7\x85\x83\x9dS\xf5\x07" TR_HOMESCREEN = b"TOIG\x80\x00@\x00\x0c\x04\x00\x00\xa5RY\x96\xdc0\x08\xe4\x06\xdc\xff\x96\xdc\x80\xa8\x16\x90z\xd2y\xf9\x18{<m\xc9\x02j\xeb\xea_\xddy\xfe~\xf14:\xfc\xe2I\x00\xf8\xff\x13\xff\xfa\xc6>\xc0\xf1\xe5\xc9y\x0f\x95\x7f;C\xfe\xd0\xe1K\xefS\x96o\xf9\xb739\x1a\n\xc7\xde\x89\xff\x11\xd8=\xd5\xcf\xb1\x9f\xf7U\xf2\xa3spx\xb0&t\xe4\xaf3x\xcaT\xec\xe50k\xb4\xe8\nl\x16\xbf`'\xf3\xa7Z\x8d-\x98h\x1c\x03\x07\xf0\xcf\xf0\x8aD\x13\xec\x1f@y\x9e\xd8\xa3\xc6\x84F*\x1dx\x02U\x00\x10\xd3\x8cF\xbb\x97y\x18J\xa5T\x18x\x1c\x02\xc6\x90\xfd\xdc\x89\x1a\x94\xb3\xeb\x01\xdc\x9f2\x8c/\xe9/\x8c$\xc6\x9c\x1e\xf8C\x8f@\x17Q\x1d\x11F\x02g\xe4A \xebO\xad\xc6\xe3F\xa7\x8b\xf830R\x82\x0b\x8e\x16\x1dL,\x14\xce\x057tht^\xfe\x00\x9e\x86\xc2\x86\xa3b~^Bl\x18\x1f\xb9+w\x11\x14\xceO\xe9\xb6W\xd8\x85\xbeX\x17\xc2\x13,M`y\xd1~\xa3/\xcd0\xed6\xda\xf5b\x15\xb5\x18\x0f_\xf6\xe2\xdc\x8d\x8ez\xdd\xd5\r^O\x9e\xb6|\xc4e\x0f\x1f\xff0k\xd4\xb8\n\x12{\x8d\x8a>\x0b5\xa2o\xf2jZ\xe5\xee\xdc\x14\xd1\xbd\xd5\xad\x95\xbe\x8c\t\x8f\xb9\xde\xc4\xa551,#`\x94'\x1b\xe7\xd53u\x8fq\xbd4v>3\x8f\xcc\x1d\xbcV>\x90^\xb3L\xc3\xde0]\x05\xec\x83\xd0\x07\xd2(\xbb\xcf+\xd0\xc7ru\xecn\x14k-\xc0|\xd2\x0e\xe8\xe08\xa8<\xdaQ+{\xad\x01\x02#\x16\x12+\xc8\xe0P\x06\xedD7\xae\xd0\xa4\x97\x84\xe32\xca;]\xd04x:\x94`\xbe\xca\x89\xe2\xcb\xc5L\x03\xac|\xe7\xd5\x1f\xe3\x08_\xee!\x04\xd2\xef\x00\xd8\xea\x91p)\xed^#\xb1\xa78eJ\x00F*\xc7\xf1\x0c\x1a\x04\xf5l\xcc\xfc\xa4\x83,c\x1e\xb1>\xc5q\x8b\xe6Y9\xc7\x07\xfa\xcf\xf9\x15\x8a\xdd\x11\x1f\x98\x82\xbe>\xbe+u#g]aC\\\x1bC\xb1\xe8P\xce2\xd6\xb6r\x12\x1c*\xd3\x92\x9d9\xf9cB\x82\xf9S.\xc2B\xe7\x9d\xcf\xdb\xf3\xfd#\xfd\x94x9p<D?\x0e0\xd0)ufMK\x9d\x84\xbf\x95\x02\x15\x04\xaf\x9b\xd7|\x9f\xf5\xc2\x19D\xe1\xe8pC=\\\xb54\xff\xfd<\xfc\x8b\x83\x19\x9aZ\x99J\x9d\xa2oP6\xb2=\xe0\xe5=Z0\x7f\xb6\xe9\xb1\x98\n\xcc \xdb\x9f\xb6\xf4\xc2\x82Z:\t\xf2\xcd\x88\xe3\x8a0\n\x13\xdd\xf9;\xdbtr\xf4kj\xa6\x90\x9d\x97\x92+#\xf4;t\x1e6\x80\xcd)\xfe\xe1\xabdD(V\xf5\xcc\xcf\xbeY\xd8\xddX\x08*\xc5\xd6\xa1\xa2\xae\x89s]\xafj\x9b\x07\x8d\xcf\xb5\n\x162\xb7\xb0\x02p\xc0.{\xbf\xd6\xc7|\x8a\x8c\xc9g\xa8DO\x85\xf6<E\x05Ek\x8c\xbfU\x13bz\xcf\xd0\x07\xcd^\x0f\x9b\x951\xa1vb\x17u:\xd2X\x91/\x0f\x9a\xae\x16T\x81\xb6\x8e\xdc,\xb0\xa1=\x11af%^\xec\x95\x83\xa9\xb8+\xd0i\xe0+#%\x02\xfd2\x84\xf3\xde\x0c\x88\x8c\x80\xf7\xc2H\xaa#\x80m\xf4\x1e\xd4#\x04J\r\xb6\xf83s\x8c\x0e\x0bx\xabw\xbe\x90\x94\x90:C\x88\x9bR`B\xc02\x1a\x08\xca-M9\xac\xa3TP\xb1\x83\xf2\x8aT\xe9\xc0c9(\xe5d\xd1\xac\xfd\x83\xf3\xb4C\x95\x04doh\xd7M\xed \xd0\x90\xc9|\x8a\x1fX\x1f\x0eI\x12\x8e&\xc3\x91NM-\x02\xeckp\x1a/\x19\x9d\xf2\xb4$\x0eG:\xbe\xb2~\x10(,\x1cd\x07\xbb]n^F@\x173'\xc63\xdf!u\xf4F\xa9\xc3\x96E!e\xc2Iv\xe8zQH=v\x89\x9a\x04a^\x10\x06\x01\x04!2\xa6\x1b\xba\x111/\xfa\xfa\x9c\x15 @\xf6\xb9{&\xcc\x84\xfa\xd6\x81\x90\xd4\x92\xcf\x89/\xc8\x80\xadP\xa3\xf4Xa\x1f\x04A\x07\xe6N\xd2oEZ\xc9\xa6(!\x8e#|\x0e\xfbq\xce\xe6\x8b-;\x06_\x04n\xdc\x8d^\x05s\xd2\xa8\x0f\xfa/\xfa\xf8\xe1x\n\xa3\xf701i7\x0c \x87\xec#\x80\x9c^X\x02\x01C\xc7\x85\x83\x9dS\xf5\x07"
def _set_expected_responses(session: Session): def _set_expected_responses(client: Client):
session.client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
if session.model is models.T1B1: if client.model is models.T1B1:
session.set_expected_responses(EXPECTED_RESPONSES_PIN_T1) client.set_expected_responses(EXPECTED_RESPONSES_PIN_T1)
else: else:
session.set_expected_responses(EXPECTED_RESPONSES_PIN_TT) client.set_expected_responses(EXPECTED_RESPONSES_PIN_TT)
def test_apply_settings(session: Session): def test_apply_settings(session: Session):
assert session.features.label == "test" assert session.features.label == "test"
with session: with session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, label="new label") device.apply_settings(session, label="new label")
assert session.features.label == "new label" assert session.features.label == "new label"
@ -72,8 +73,8 @@ def test_apply_settings(session: Session):
def test_apply_settings_rotation(session: Session): def test_apply_settings_rotation(session: Session):
assert session.features.display_rotation is None assert session.features.display_rotation is None
with session: with session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, display_rotation=messages.DisplayRotation.West) device.apply_settings(session, display_rotation=messages.DisplayRotation.West)
assert session.features.display_rotation == messages.DisplayRotation.West assert session.features.display_rotation == messages.DisplayRotation.West
@ -81,20 +82,21 @@ def test_apply_settings_rotation(session: Session):
@pytest.mark.setup_client(pin=PIN4, passphrase=False) @pytest.mark.setup_client(pin=PIN4, passphrase=False)
def test_apply_settings_passphrase(session: Session): def test_apply_settings_passphrase(session: Session):
with session: client = session.client
_set_expected_responses(session) with client:
_set_expected_responses(client)
device.apply_settings(session, use_passphrase=True) device.apply_settings(session, use_passphrase=True)
assert session.features.passphrase_protection is True assert session.features.passphrase_protection is True
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, use_passphrase=False) device.apply_settings(session, use_passphrase=False)
assert session.features.passphrase_protection is False assert session.features.passphrase_protection is False
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, use_passphrase=True) device.apply_settings(session, use_passphrase=True)
assert session.features.passphrase_protection is True assert session.features.passphrase_protection is True
@ -103,32 +105,33 @@ def test_apply_settings_passphrase(session: Session):
@pytest.mark.setup_client(passphrase=False) @pytest.mark.setup_client(passphrase=False)
@pytest.mark.models("core") @pytest.mark.models("core")
def test_apply_settings_passphrase_on_device(session: Session): def test_apply_settings_passphrase_on_device(session: Session):
client = session.client
# enable passphrase # enable passphrase
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, use_passphrase=True) device.apply_settings(session, use_passphrase=True)
assert session.features.passphrase_protection is True assert session.features.passphrase_protection is True
# enable force on device # enable force on device
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, passphrase_always_on_device=True) device.apply_settings(session, passphrase_always_on_device=True)
assert session.features.passphrase_protection is True assert session.features.passphrase_protection is True
assert session.features.passphrase_always_on_device is True assert session.features.passphrase_always_on_device is True
# turning off the passphrase should also clear the always_on_device setting # turning off the passphrase should also clear the always_on_device setting
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, use_passphrase=False) device.apply_settings(session, use_passphrase=False)
assert session.features.passphrase_protection is False assert session.features.passphrase_protection is False
assert session.features.passphrase_always_on_device is False assert session.features.passphrase_always_on_device is False
# and turning it back on does not modify always_on_device # and turning it back on does not modify always_on_device
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, use_passphrase=True) device.apply_settings(session, use_passphrase=True)
assert session.features.passphrase_protection is True assert session.features.passphrase_protection is True
@ -137,35 +140,36 @@ def test_apply_settings_passphrase_on_device(session: Session):
@pytest.mark.models("safe3") @pytest.mark.models("safe3")
def test_apply_homescreen_tr_toif_good(session: Session): def test_apply_homescreen_tr_toif_good(session: Session):
with session: with session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, homescreen=TR_HOMESCREEN) device.apply_settings(session, homescreen=TR_HOMESCREEN)
# Revert to default settings # Revert to default settings
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, homescreen=b"") device.apply_settings(session, homescreen=b"")
@pytest.mark.models("safe3") @pytest.mark.models("safe3")
@pytest.mark.setup_client(pin=None) # so that "PIN NOT SET" is shown in the header @pytest.mark.setup_client(pin=None) # so that "PIN NOT SET" is shown in the header
def test_apply_homescreen_tr_toif_with_notification(session: Session): def test_apply_homescreen_tr_toif_with_notification(session: Session):
with session: with session.client as client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, homescreen=TR_HOMESCREEN) device.apply_settings(session, homescreen=TR_HOMESCREEN)
@pytest.mark.models("safe3") @pytest.mark.models("safe3")
def test_apply_homescreen_tr_toif_with_long_label(session: Session): def test_apply_homescreen_tr_toif_with_long_label(session: Session):
with session: client = session.client
_set_expected_responses(session) with session.client:
_set_expected_responses(client)
device.apply_settings(session, homescreen=TR_HOMESCREEN) device.apply_settings(session, homescreen=TR_HOMESCREEN)
# Showing longer label # Showing longer label
with session: with session.client:
device.apply_settings(session, label="My long label") device.apply_settings(session, label="My long label")
# Showing label that will not fit on the line # Showing label that will not fit on the line
with session: with session.client:
device.apply_settings(session, label="My even longer label") device.apply_settings(session, label="My even longer label")
@ -173,8 +177,8 @@ def test_apply_homescreen_tr_toif_with_long_label(session: Session):
def test_apply_homescreen_tr_toif_wrong_size(session: Session): def test_apply_homescreen_tr_toif_wrong_size(session: Session):
# 64x64 img # 64x64 img
img = b"TOIG@\x00@\x009\x02\x00\x00}R\xdb\x81$A\x08\"\x03\xf3\xcf\xd2\x0c<\x01-{\xefc\xe6\xd5\xbbU\xa2\x08T\xd6\xcfw\xf4\xe7\xc7\xb7X\xf1\xe3\x1bl\xf0\xf7\x1b\xf8\x1f\xcf\xe7}\xe1\x83\xcf|>\x8d%\x14\xa5\xb3\xe9p5\xa1;~4:\xcd\xe0&\x11\x1d\xe9\xf6\xa1\x1fw\xf54\x95eWx\xda\xd0u\x91\x86\xb8\xbc\xdf\xdc\x008f\x15\xc6\xf6\x7f\xf0T\xb8\xc1\xa3\xc5_A\xc0G\x930\xe7\xdc=\xd5\xa7\xc1\xbcI\x16\xb8s\x9c&\xaa\x06\xc1}\x8b\x19\x9d'c\xc3\xe3^\xc3m\xb6n\xb0(\x16\xf6\xdeg\xb3\x96:i\xe5\x9c\x02\x93\x9fF\x9f-\xa7\"w\xf3X\x9f\x87\x08\x84\"v,\xab!9:<j+\xcb\xf3_\xc7\xd6^<\xce\xc1\xb8!\xec\x8f/\xb1\xc1\x8f\xbd\xcc\x06\x90\x0e\x98)[\xdb\x15\x99\xaf\xf2~\x8e\xd0\xdb\xcd\xfd\x90\x12\xb6\xdd\xc3\xdd|\x96$\x01P\x86H\xbc\xc0}\xa2\x08\xe5\x82\x06\xd2\xeb\x07[\r\xe4\xdeP\xf4\x86;\xa5\x14c\x12\xe3\xb16x\xad\xc7\x1d\x02\xef\x86<\xc6\x95\xd3/\xc4 \xa1\xf5V\xe2\t\xb2\x8a\xd6`\xf2\xcf\xb7\xd6\x07\xdf8X\xa7\x18\x03\x96\x82\xa4 \xeb.*kP\xceu\x9d~}H\xe9\xb8\x04<4\xff\xf8\xcf\xf6\xa0\xf2\xfcM\xe3/?k\xff\x18\x1d\xb1\xee\xc5\xf5\x1f\x01\x14\x03;\x1bU\x1f~\xcf\xb3\xf7w\xe5\nMfd/\xb93\x9fq\x9bQ\xb7'\xbfvq\x1d\xce\r\xbaDo\x90\xbc\xc5:?;\x84y\x8a\x1e\xad\xe9\xb7\x14\x10~\x9b@\xf8\x82\xdc\x89\xe7\xf0\xe0k4o\x9a\xa0\xc4\xb9\xba\xc56\x01i\x85EO'e6\xb7\x15\xb4G\x05\xe1\xe7%\xd3&\x93\x91\xc9CTQ\xeb\xcc\xd0\xd7E9\xa9JK\xcc\x00\x95(\xdc.\xd2#7:Yo}y_*\x1a\xae6)\x97\x9d\xc0\x80vl\x02\\M\xfe\xc9sW\xa8\xfbD\x99\xb8\xb0:\xbc\x80\xfd\xef\xd3\x94\xbe\x18j9z\x12S\xa1\xec$\x1c\xe3\xd1\xd0\xf4\xdd\xbfI\xf1rBj\x0f\x1cz\x1d\xf7\xa5tR\xb3\xfc\xa4\xd0\xfah\xc3Mj\xbe\x14r\x9d\x84z\xd2\x7f\x13\xb4w\xce\xa0\xaeW\xa4\x18\x0b\xe4\x8f\xe6\xc3\xbeQ\x93\xb0L<J\xe3g9\xb5W#f\xd1\x0b\x96|\xd6z1;\x85\x7f\xe3\xe6[\x02A\xdc\xa4\x02\x1b\x91\x88\x7f" img = b"TOIG@\x00@\x009\x02\x00\x00}R\xdb\x81$A\x08\"\x03\xf3\xcf\xd2\x0c<\x01-{\xefc\xe6\xd5\xbbU\xa2\x08T\xd6\xcfw\xf4\xe7\xc7\xb7X\xf1\xe3\x1bl\xf0\xf7\x1b\xf8\x1f\xcf\xe7}\xe1\x83\xcf|>\x8d%\x14\xa5\xb3\xe9p5\xa1;~4:\xcd\xe0&\x11\x1d\xe9\xf6\xa1\x1fw\xf54\x95eWx\xda\xd0u\x91\x86\xb8\xbc\xdf\xdc\x008f\x15\xc6\xf6\x7f\xf0T\xb8\xc1\xa3\xc5_A\xc0G\x930\xe7\xdc=\xd5\xa7\xc1\xbcI\x16\xb8s\x9c&\xaa\x06\xc1}\x8b\x19\x9d'c\xc3\xe3^\xc3m\xb6n\xb0(\x16\xf6\xdeg\xb3\x96:i\xe5\x9c\x02\x93\x9fF\x9f-\xa7\"w\xf3X\x9f\x87\x08\x84\"v,\xab!9:<j+\xcb\xf3_\xc7\xd6^<\xce\xc1\xb8!\xec\x8f/\xb1\xc1\x8f\xbd\xcc\x06\x90\x0e\x98)[\xdb\x15\x99\xaf\xf2~\x8e\xd0\xdb\xcd\xfd\x90\x12\xb6\xdd\xc3\xdd|\x96$\x01P\x86H\xbc\xc0}\xa2\x08\xe5\x82\x06\xd2\xeb\x07[\r\xe4\xdeP\xf4\x86;\xa5\x14c\x12\xe3\xb16x\xad\xc7\x1d\x02\xef\x86<\xc6\x95\xd3/\xc4 \xa1\xf5V\xe2\t\xb2\x8a\xd6`\xf2\xcf\xb7\xd6\x07\xdf8X\xa7\x18\x03\x96\x82\xa4 \xeb.*kP\xceu\x9d~}H\xe9\xb8\x04<4\xff\xf8\xcf\xf6\xa0\xf2\xfcM\xe3/?k\xff\x18\x1d\xb1\xee\xc5\xf5\x1f\x01\x14\x03;\x1bU\x1f~\xcf\xb3\xf7w\xe5\nMfd/\xb93\x9fq\x9bQ\xb7'\xbfvq\x1d\xce\r\xbaDo\x90\xbc\xc5:?;\x84y\x8a\x1e\xad\xe9\xb7\x14\x10~\x9b@\xf8\x82\xdc\x89\xe7\xf0\xe0k4o\x9a\xa0\xc4\xb9\xba\xc56\x01i\x85EO'e6\xb7\x15\xb4G\x05\xe1\xe7%\xd3&\x93\x91\xc9CTQ\xeb\xcc\xd0\xd7E9\xa9JK\xcc\x00\x95(\xdc.\xd2#7:Yo}y_*\x1a\xae6)\x97\x9d\xc0\x80vl\x02\\M\xfe\xc9sW\xa8\xfbD\x99\xb8\xb0:\xbc\x80\xfd\xef\xd3\x94\xbe\x18j9z\x12S\xa1\xec$\x1c\xe3\xd1\xd0\xf4\xdd\xbfI\xf1rBj\x0f\x1cz\x1d\xf7\xa5tR\xb3\xfc\xa4\xd0\xfah\xc3Mj\xbe\x14r\x9d\x84z\xd2\x7f\x13\xb4w\xce\xa0\xaeW\xa4\x18\x0b\xe4\x8f\xe6\xc3\xbeQ\x93\xb0L<J\xe3g9\xb5W#f\xd1\x0b\x96|\xd6z1;\x85\x7f\xe3\xe6[\x02A\xdc\xa4\x02\x1b\x91\x88\x7f"
with pytest.raises(exceptions.TrezorFailure), session: with pytest.raises(exceptions.TrezorFailure), session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, homescreen=img) device.apply_settings(session, homescreen=img)
@ -182,15 +186,15 @@ def test_apply_homescreen_tr_toif_wrong_size(session: Session):
def test_apply_homescreen_tr_upload_jpeg_fail(session: Session): def test_apply_homescreen_tr_upload_jpeg_fail(session: Session):
with open(HERE / "test_bg.jpg", "rb") as f: with open(HERE / "test_bg.jpg", "rb") as f:
img = f.read() img = f.read()
with pytest.raises(exceptions.TrezorFailure), session: with pytest.raises(exceptions.TrezorFailure), session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, homescreen=img) device.apply_settings(session, homescreen=img)
@pytest.mark.models("safe3") @pytest.mark.models("safe3")
def test_apply_homescreen_tr_upload_t1_fail(session: Session): def test_apply_homescreen_tr_upload_t1_fail(session: Session):
with pytest.raises(exceptions.TrezorFailure), session: with pytest.raises(exceptions.TrezorFailure), session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, homescreen=T1_HOMESCREEN) device.apply_settings(session, homescreen=T1_HOMESCREEN)
@ -198,8 +202,8 @@ def test_apply_homescreen_tr_upload_t1_fail(session: Session):
def test_apply_homescreen_toif(session: Session): def test_apply_homescreen_toif(session: Session):
img = b"TOIf\x90\x00\x90\x00~\x00\x00\x00\xed\xd2\xcb\r\x83@\x10D\xc1^.\xde#!\xac31\x99\x10\x8aC%\x14~\x16\x92Y9\x02WI3\x01<\xf5cI2d\x1es(\xe1[\xdbn\xba\xca\xe8s7\xa4\xd5\xd4\xb3\x13\xbdw\xf6:\xf3\xd1\xe7%\xc7]\xdd_\xb3\x9e\x9f\x9e\x9fN\xed\xaaE\xef\xdc\xcf$D\xa7\xa4X\r\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0OV" img = b"TOIf\x90\x00\x90\x00~\x00\x00\x00\xed\xd2\xcb\r\x83@\x10D\xc1^.\xde#!\xac31\x99\x10\x8aC%\x14~\x16\x92Y9\x02WI3\x01<\xf5cI2d\x1es(\xe1[\xdbn\xba\xca\xe8s7\xa4\xd5\xd4\xb3\x13\xbdw\xf6:\xf3\xd1\xe7%\xc7]\xdd_\xb3\x9e\x9f\x9e\x9fN\xed\xaaE\xef\xdc\xcf$D\xa7\xa4X\r\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0OV"
with pytest.raises(exceptions.TrezorFailure), session: with pytest.raises(exceptions.TrezorFailure), session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, homescreen=img) device.apply_settings(session, homescreen=img)
@ -208,11 +212,11 @@ def test_apply_homescreen_jpeg(session: Session):
with open(HERE / "test_bg.jpg", "rb") as f: with open(HERE / "test_bg.jpg", "rb") as f:
img = f.read() img = f.read()
# raise Exception("FAILS FOR SOME REASON ") # raise Exception("FAILS FOR SOME REASON ")
with session: with session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, homescreen=img) device.apply_settings(session, homescreen=img)
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, homescreen=b"") device.apply_settings(session, homescreen=b"")
@ -267,8 +271,8 @@ def test_apply_homescreen_jpeg_progressive(session: Session):
b"\x00\x00\x00\x00\x90\xff\xda\x00\x08\x01\x01\x00\x01?\x10a?\xff\xd9" b"\x00\x00\x00\x00\x90\xff\xda\x00\x08\x01\x01\x00\x01?\x10a?\xff\xd9"
) )
with pytest.raises(exceptions.TrezorFailure), session: with pytest.raises(exceptions.TrezorFailure), session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, homescreen=img) device.apply_settings(session, homescreen=img)
@ -313,76 +317,79 @@ def test_apply_homescreen_jpeg_wrong_size(session: Session):
b"\x00\x00\x1f\xff\xd9" b"\x00\x00\x1f\xff\xd9"
) )
with pytest.raises(exceptions.TrezorFailure), session: with pytest.raises(exceptions.TrezorFailure), session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, homescreen=img) device.apply_settings(session, homescreen=img)
@pytest.mark.models("legacy") @pytest.mark.models("legacy")
def test_apply_homescreen(session: Session): def test_apply_homescreen(session: Session):
with session: with session.client as client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, homescreen=T1_HOMESCREEN) device.apply_settings(session, homescreen=T1_HOMESCREEN)
@pytest.mark.setup_client(pin=None) @pytest.mark.setup_client(pin=None)
def test_safety_checks(session: Session): def test_safety_checks(session: Session):
client = session.client
def get_bad_address(): def get_bad_address():
btc.get_address(session, "Bitcoin", parse_path("m/44h"), show_display=True) btc.get_address(session, "Bitcoin", parse_path("m/44h"), show_display=True)
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), session: with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client:
session.set_expected_responses([messages.Failure]) client.set_expected_responses([messages.Failure])
get_bad_address() get_bad_address()
if session.model is not models.T1B1: if session.model is not models.T1B1:
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings( device.apply_settings(
session, safety_checks=messages.SafetyCheckLevel.PromptAlways session, safety_checks=messages.SafetyCheckLevel.PromptAlways
) )
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways
with session, session.client as client: with client:
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address] [messages.ButtonRequest, messages.ButtonRequest, messages.Address]
) )
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
get_bad_address() get_bad_address()
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict) device.apply_settings(session, safety_checks=messages.SafetyCheckLevel.Strict)
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), session: with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client:
session.set_expected_responses([messages.Failure]) client.set_expected_responses([messages.Failure])
get_bad_address() get_bad_address()
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings( device.apply_settings(
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
) )
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily
with session, session.client as client: with client:
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address] [messages.ButtonRequest, messages.ButtonRequest, messages.Address]
) )
if session.model is not models.T1B1: if session.model is not models.T1B1:
IF = InputFlowConfirmAllWarnings(client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
get_bad_address() get_bad_address()
@pytest.mark.models("core") @pytest.mark.models("core")
def test_experimental_features(session: Session): def test_experimental_features(session: Session):
client = session.client
def experimental_call(): def experimental_call():
misc.get_nonce(session) misc.get_nonce(session)
@ -390,38 +397,38 @@ def test_experimental_features(session: Session):
assert session.features.experimental_features is None assert session.features.experimental_features is None
# unlock # unlock
with session: with session.client:
_set_expected_responses(session) _set_expected_responses(client)
device.apply_settings(session, label="new label") device.apply_settings(session, label="new label")
assert not session.features.experimental_features assert not session.features.experimental_features
with pytest.raises(exceptions.TrezorFailure, match="DataError"), session: with pytest.raises(exceptions.TrezorFailure, match="DataError"), client:
session.set_expected_responses([messages.Failure]) client.set_expected_responses([messages.Failure])
experimental_call() experimental_call()
with session: with client:
session.set_expected_responses(EXPECTED_RESPONSES_EXPERIMENTAL_FEATURES) client.set_expected_responses(EXPECTED_RESPONSES_EXPERIMENTAL_FEATURES)
device.apply_settings(session, experimental_features=True) device.apply_settings(session, experimental_features=True)
assert session.features.experimental_features assert session.features.experimental_features
with session: with client:
session.set_expected_responses([messages.Nonce]) client.set_expected_responses([messages.Nonce])
experimental_call() experimental_call()
# relock and try again # relock and try again
session.lock() session.lock()
with session, session.client as client: with client:
client.use_pin_sequence([PIN4]) session.client.use_pin_sequence([PIN4])
session.set_expected_responses([messages.ButtonRequest, messages.Nonce]) client.set_expected_responses([messages.ButtonRequest, messages.Nonce])
experimental_call() experimental_call()
@pytest.mark.setup_client(pin=None) @pytest.mark.setup_client(pin=None)
def test_label_too_long(session: Session): def test_label_too_long(session: Session):
with pytest.raises(exceptions.TrezorFailure), session: with pytest.raises(exceptions.TrezorFailure), session.client as client:
session.set_expected_responses([messages.Failure]) client.set_expected_responses([messages.Failure])
device.apply_settings(session, label="A" * 33) device.apply_settings(session, label="A" * 33)

View File

@ -45,7 +45,7 @@ def test_backup_bip39(session: Session):
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client: with session.client as client:
IF = InputFlowBip39Backup(client) IF = InputFlowBip39Backup(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup(session) device.backup(session)
@ -72,7 +72,7 @@ def test_backup_slip39_basic(session: Session, click_info: bool):
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client: with session.client as client:
IF = InputFlowSlip39BasicBackup(client, click_info) IF = InputFlowSlip39BasicBackup(session.client, click_info)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup(session) device.backup(session)
@ -97,7 +97,8 @@ def test_backup_slip39_single(session: Session):
with session.client as client: with session.client as client:
IF = InputFlowBip39Backup( IF = InputFlowBip39Backup(
client, confirm_success=(client.layout_type is not LayoutType.Delizia) session.client,
confirm_success=(session.client.layout_type is not LayoutType.Delizia),
) )
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup(session) device.backup(session)
@ -127,7 +128,7 @@ def test_backup_slip39_advanced(session: Session, click_info: bool):
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client: with session.client as client:
IF = InputFlowSlip39AdvancedBackup(client, click_info) IF = InputFlowSlip39AdvancedBackup(session.client, click_info)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup(session) device.backup(session)
@ -158,7 +159,7 @@ def test_backup_slip39_custom(session: Session, share_threshold, share_count):
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
with session.client as client: with session.client as client:
IF = InputFlowSlip39CustomBackup(client, share_count) IF = InputFlowSlip39CustomBackup(session.client, share_count)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup( device.backup(
session, group_threshold=1, groups=[(share_threshold, share_count)] session, group_threshold=1, groups=[(share_threshold, share_count)]

View File

@ -35,7 +35,7 @@ pytestmark = pytest.mark.models("legacy")
def _set_wipe_code(session: Session, pin, wipe_code): def _set_wipe_code(session: Session, pin, wipe_code):
# Set/change wipe code. # Set/change wipe code.
with session.client as client, session: with session.client as client:
if session.features.pin_protection: if session.features.pin_protection:
pins = [pin, wipe_code, wipe_code] pins = [pin, wipe_code, wipe_code]
pin_matrices = [ pin_matrices = [
@ -51,7 +51,7 @@ def _set_wipe_code(session: Session, pin, wipe_code):
] ]
client.use_pin_sequence(pins) client.use_pin_sequence(pins)
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] + pin_matrices + [messages.Success] [messages.ButtonRequest()] + pin_matrices + [messages.Success]
) )
device.change_wipe_code(session) device.change_wipe_code(session)
@ -112,9 +112,9 @@ def test_set_wipe_code_mismatch(session: Session):
assert session.features.wipe_code_protection is False assert session.features.wipe_code_protection is False
# Let's set a new wipe code. # Let's set a new wipe code.
with session.client as client, session: with session.client as client:
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6]) client.use_pin_sequence([WIPE_CODE4, WIPE_CODE6])
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(), messages.ButtonRequest(),
messages.PinMatrixRequest(type=PinType.WipeCodeFirst), messages.PinMatrixRequest(type=PinType.WipeCodeFirst),
@ -136,9 +136,9 @@ def test_set_wipe_code_to_pin(session: Session):
assert session.features.wipe_code_protection is None assert session.features.wipe_code_protection is None
# Let's try setting the wipe code to the curent PIN value. # Let's try setting the wipe code to the curent PIN value.
with session.client as client, session: with session.client as client:
client.use_pin_sequence([PIN4, PIN4]) client.use_pin_sequence([PIN4, PIN4])
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(), messages.ButtonRequest(),
messages.PinMatrixRequest(type=PinType.Current), messages.PinMatrixRequest(type=PinType.Current),
@ -160,9 +160,9 @@ def test_set_pin_to_wipe_code(session: Session):
_set_wipe_code(session, None, WIPE_CODE4) _set_wipe_code(session, None, WIPE_CODE4)
# Try to set the PIN to the current wipe code value. # Try to set the PIN to the current wipe code value.
with session.client as client, session: with session.client as client:
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(), messages.ButtonRequest(),
messages.PinMatrixRequest(type=PinType.NewFirst), messages.PinMatrixRequest(type=PinType.NewFirst),

View File

@ -37,13 +37,13 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str):
assert session.features.wipe_code_protection is True assert session.features.wipe_code_protection is True
# Try to change the PIN to the current wipe code value. The operation should fail. # Try to change the PIN to the current wipe code value. The operation should fail.
with session, session.client as client, pytest.raises(TrezorFailure): with session.client as client, pytest.raises(TrezorFailure):
client.use_pin_sequence([pin, wipe_code, wipe_code]) client.use_pin_sequence([pin, wipe_code, wipe_code])
if session.client.layout_type is LayoutType.Caesar: if session.client.layout_type is LayoutType.Caesar:
br_count = 6 br_count = 6
else: else:
br_count = 5 br_count = 5
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * br_count [messages.ButtonRequest()] * br_count
+ [messages.Failure(code=messages.FailureType.PinInvalid)] + [messages.Failure(code=messages.FailureType.PinInvalid)]
) )
@ -51,7 +51,7 @@ def _check_wipe_code(session: Session, pin: str, wipe_code: str):
def _ensure_unlocked(session: Session, pin: str): def _ensure_unlocked(session: Session, pin: str):
with session, session.client as client: with session.client as client:
client.use_pin_sequence([pin]) client.use_pin_sequence([pin])
btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH) btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
@ -60,19 +60,20 @@ def _ensure_unlocked(session: Session, pin: str):
@pytest.mark.setup_client(pin=PIN4) @pytest.mark.setup_client(pin=PIN4)
def test_set_remove_wipe_code(session: Session): def test_set_remove_wipe_code(session: Session):
client = session.client
# Test set wipe code. # Test set wipe code.
assert session.features.wipe_code_protection is None assert session.features.wipe_code_protection is None
_ensure_unlocked(session, PIN4) _ensure_unlocked(session, PIN4)
assert session.features.wipe_code_protection is False assert session.features.wipe_code_protection is False
if session.client.layout_type is LayoutType.Caesar: if client.layout_type is LayoutType.Caesar:
br_count = 6 br_count = 6
else: else:
br_count = 5 br_count = 5
with session, session.client as client: with client:
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success] [messages.ButtonRequest()] * br_count + [messages.Success]
) )
client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX]) client.use_pin_sequence([PIN4, WIPE_CODE_MAX, WIPE_CODE_MAX])
@ -83,8 +84,8 @@ def test_set_remove_wipe_code(session: Session):
_check_wipe_code(session, PIN4, WIPE_CODE_MAX) _check_wipe_code(session, PIN4, WIPE_CODE_MAX)
# Test change wipe code. # Test change wipe code.
with session, session.client as client: with client:
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success] [messages.ButtonRequest()] * br_count + [messages.Success]
) )
client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6]) client.use_pin_sequence([PIN4, WIPE_CODE6, WIPE_CODE6])
@ -95,8 +96,8 @@ def test_set_remove_wipe_code(session: Session):
_check_wipe_code(session, PIN4, WIPE_CODE6) _check_wipe_code(session, PIN4, WIPE_CODE6)
# Test remove wipe code. # Test remove wipe code.
with session, session.client as client: with client:
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * 3 + [messages.Success] [messages.ButtonRequest()] * 3 + [messages.Success]
) )
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
@ -107,8 +108,10 @@ def test_set_remove_wipe_code(session: Session):
def test_set_wipe_code_mismatch(session: Session): def test_set_wipe_code_mismatch(session: Session):
with session, session.client as client, pytest.raises(TrezorFailure): with session.client as client, pytest.raises(TrezorFailure):
IF = InputFlowNewCodeMismatch(client, WIPE_CODE4, WIPE_CODE6, what="wipe_code") IF = InputFlowNewCodeMismatch(
session.client, WIPE_CODE4, WIPE_CODE6, what="wipe_code"
)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.change_wipe_code(session) device.change_wipe_code(session)
@ -122,12 +125,12 @@ def test_set_wipe_code_mismatch(session: Session):
def test_set_wipe_code_to_pin(session: Session): def test_set_wipe_code_to_pin(session: Session):
_ensure_unlocked(session, PIN4) _ensure_unlocked(session, PIN4)
with session, session.client as client: with session.client as client:
if client.layout_type is LayoutType.Caesar: if client.layout_type is LayoutType.Caesar:
br_count = 8 br_count = 8
else: else:
br_count = 7 br_count = 7
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success], [messages.ButtonRequest()] * br_count + [messages.Success],
) )
client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4]) client.use_pin_sequence([PIN4, PIN4, WIPE_CODE4, WIPE_CODE4])
@ -139,25 +142,26 @@ def test_set_wipe_code_to_pin(session: Session):
def test_set_pin_to_wipe_code(session: Session): def test_set_pin_to_wipe_code(session: Session):
client = session.client
# Set wipe code. # Set wipe code.
with session, session.client as client: with client:
if client.layout_type is LayoutType.Caesar: if client.layout_type is LayoutType.Caesar:
br_count = 5 br_count = 5
else: else:
br_count = 4 br_count = 4
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * br_count + [messages.Success] [messages.ButtonRequest()] * br_count + [messages.Success]
) )
client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4]) client.use_pin_sequence([WIPE_CODE4, WIPE_CODE4])
device.change_wipe_code(session) device.change_wipe_code(session)
# Try to set the PIN to the current wipe code value. # Try to set the PIN to the current wipe code value.
with session, session.client as client, pytest.raises(TrezorFailure): with client, pytest.raises(TrezorFailure):
if client.layout_type is LayoutType.Caesar: if client.layout_type is LayoutType.Caesar:
br_count = 6 br_count = 6
else: else:
br_count = 4 br_count = 4
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest()] * br_count [messages.ButtonRequest()] * br_count
+ [messages.Failure(code=messages.FailureType.PinInvalid)] + [messages.Failure(code=messages.FailureType.PinInvalid)]
) )

View File

@ -33,16 +33,16 @@ pytestmark = pytest.mark.models("legacy")
def _check_pin(session: Session, pin): def _check_pin(session: Session, pin):
session.lock() session.lock()
with session, session.client as client: with session.client as client:
client.use_pin_sequence([pin]) client.use_pin_sequence([pin])
session.set_expected_responses([messages.PinMatrixRequest, messages.Address]) client.set_expected_responses([messages.PinMatrixRequest, messages.Address])
get_test_address(session) get_test_address(session)
def _check_no_pin(session: Session): def _check_no_pin(session: Session):
session.lock() session.lock()
with session: with session.client as client:
session.set_expected_responses([messages.Address]) client.set_expected_responses([messages.Address])
get_test_address(session) get_test_address(session)
@ -53,9 +53,9 @@ def test_set_pin(session: Session):
_check_no_pin(session) _check_no_pin(session)
# Let's set new PIN # Let's set new PIN
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN_MAX, PIN_MAX]) client.use_pin_sequence([PIN_MAX, PIN_MAX])
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
messages.PinMatrixRequest, messages.PinMatrixRequest,
@ -78,9 +78,9 @@ def test_change_pin(session: Session):
_check_pin(session, PIN4) _check_pin(session, PIN4)
# Let's change PIN # Let's change PIN
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX])
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
messages.PinMatrixRequest, messages.PinMatrixRequest,
@ -104,9 +104,9 @@ def test_remove_pin(session: Session):
_check_pin(session, PIN4) _check_pin(session, PIN4)
# Let's remove PIN # Let's remove PIN
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
messages.PinMatrixRequest, messages.PinMatrixRequest,
@ -126,12 +126,10 @@ def test_set_mismatch(session: Session):
_check_no_pin(session) _check_no_pin(session)
# Let's set new PIN # Let's set new PIN
with session, session.client as client, pytest.raises( with session.client as client, pytest.raises(TrezorFailure, match="PIN mismatch"):
TrezorFailure, match="PIN mismatch"
):
# use different PINs for first and second attempt. This will fail. # use different PINs for first and second attempt. This will fail.
client.use_pin_sequence([PIN4, PIN_MAX]) client.use_pin_sequence([PIN4, PIN_MAX])
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
messages.PinMatrixRequest, messages.PinMatrixRequest,
@ -152,11 +150,9 @@ def test_change_mismatch(session: Session):
assert session.features.pin_protection is True assert session.features.pin_protection is True
# Let's set new PIN # Let's set new PIN
with session, session.client as client, pytest.raises( with session.client as client, pytest.raises(TrezorFailure, match="PIN mismatch"):
TrezorFailure, match="PIN mismatch"
):
client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"]) client.use_pin_sequence([PIN4, PIN6, PIN6 + "3"])
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
messages.PinMatrixRequest, messages.PinMatrixRequest,

View File

@ -37,13 +37,13 @@ pytestmark = pytest.mark.models("core")
def _check_pin(session: Session, pin: str): def _check_pin(session: Session, pin: str):
with session, session.client as client: with session.client as client:
client.ui.__init__(client.debug) client.ui.__init__(session.client.debug)
client.use_pin_sequence([pin, pin, pin, pin, pin, pin]) client.use_pin_sequence([pin, pin, pin, pin, pin, pin])
session.lock() session.lock()
assert session.features.pin_protection is True assert session.features.pin_protection is True
assert session.features.unlocked is False assert session.features.unlocked is False
session.set_expected_responses([messages.ButtonRequest, messages.Address]) client.set_expected_responses([messages.ButtonRequest, messages.Address])
btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH) btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
@ -51,8 +51,8 @@ def _check_no_pin(session: Session):
session.lock() session.lock()
assert session.features.pin_protection is False assert session.features.pin_protection is False
with session: with session.client as client:
session.set_expected_responses([messages.Address]) client.set_expected_responses([messages.Address])
btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH) btc.get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
@ -63,13 +63,13 @@ def test_set_pin(session: Session):
_check_no_pin(session) _check_no_pin(session)
# Let's set new PIN # Let's set new PIN
with session, session.client as client: with session.client as client:
if client.layout_type is LayoutType.Caesar: if client.layout_type is LayoutType.Caesar:
br_count = 6 br_count = 6
else: else:
br_count = 4 br_count = 4
client.use_pin_sequence([PIN_MAX, PIN_MAX]) client.use_pin_sequence([PIN_MAX, PIN_MAX])
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest] * br_count + [messages.Success] [messages.ButtonRequest] * br_count + [messages.Success]
) )
device.change_pin(session) device.change_pin(session)
@ -86,13 +86,13 @@ def test_change_pin(session: Session):
_check_pin(session, PIN4) _check_pin(session, PIN4)
# Let's change PIN # Let's change PIN
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX]) client.use_pin_sequence([PIN4, PIN_MAX, PIN_MAX])
if client.layout_type is LayoutType.Caesar: if client.layout_type is LayoutType.Caesar:
br_count = 6 br_count = 6
else: else:
br_count = 5 br_count = 5
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest] * br_count [messages.ButtonRequest] * br_count
+ [messages.Success] # , messages.Features] + [messages.Success] # , messages.Features]
) )
@ -113,11 +113,9 @@ def test_remove_pin(session: Session):
_check_pin(session, PIN4) _check_pin(session, PIN4)
# Let's remove PIN # Let's remove PIN
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
session.set_expected_responses( client.set_expected_responses([messages.ButtonRequest] * 3 + [messages.Success])
[messages.ButtonRequest] * 3 + [messages.Success]
)
device.change_pin(session, remove=True) device.change_pin(session, remove=True)
# Check that there's no PIN protection now # Check that there's no PIN protection now
@ -132,8 +130,8 @@ def test_set_failed(session: Session):
# Check that there's no PIN protection # Check that there's no PIN protection
_check_no_pin(session) _check_no_pin(session)
with session, session.client as client, pytest.raises(TrezorFailure): with session.client as client, pytest.raises(TrezorFailure):
IF = InputFlowNewCodeMismatch(client, PIN4, PIN60, what="pin") IF = InputFlowNewCodeMismatch(session.client, PIN4, PIN60, what="pin")
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.change_pin(session) device.change_pin(session)
@ -151,7 +149,7 @@ def test_change_failed(session: Session):
# Check current PIN value # Check current PIN value
_check_pin(session, PIN4) _check_pin(session, PIN4)
with session, session.client as client, pytest.raises(Cancelled): with session.client as client, pytest.raises(Cancelled):
IF = InputFlowCodeChangeFail(session, PIN4, "457891", "381847") IF = InputFlowCodeChangeFail(session, PIN4, "457891", "381847")
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
@ -170,8 +168,8 @@ def test_change_invalid_current(session: Session):
# Check current PIN value # Check current PIN value
_check_pin(session, PIN4) _check_pin(session, PIN4)
with session, session.client as client, pytest.raises(TrezorFailure): with session.client as client, pytest.raises(TrezorFailure):
IF = InputFlowWrongPIN(client, PIN60) IF = InputFlowWrongPIN(session.client, PIN60)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.change_pin(session) device.change_pin(session)
@ -200,7 +198,7 @@ def test_pin_menu_cancel_setup(session: Session):
# tap to confirm # tap to confirm
debug.click(debug.screen_buttons.tap_to_confirm()) debug.click(debug.screen_buttons.tap_to_confirm())
with session, session.client as client, pytest.raises(Cancelled): with session.client as client, pytest.raises(Cancelled):
client.set_input_flow(cancel_pin_setup_input_flow) client.set_input_flow(cancel_pin_setup_input_flow)
session.call(messages.ChangePin()) session.call(messages.ChangePin())
_check_no_pin(session) _check_no_pin(session)

View File

@ -19,13 +19,14 @@ from trezorlib.debuglink import SessionDebugWrapper as Session
def test_ping(session: Session): def test_ping(session: Session):
with session: client = session.client
session.set_expected_responses([messages.Success]) with client:
client.set_expected_responses([messages.Success])
res = session.call(messages.Ping(message="random data")) res = session.call(messages.Ping(message="random data"))
assert res.message == "random data" assert res.message == "random data"
with session: with client:
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall),
messages.Success, messages.Success,

View File

@ -45,7 +45,6 @@ def test_wipe_device(client: Client):
@pytest.mark.setup_client(pin=PIN4) @pytest.mark.setup_client(pin=PIN4)
def test_autolock_not_retained(session: Session): def test_autolock_not_retained(session: Session):
client = session.client client = session.client
with client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
device.apply_settings(session, auto_lock_delay_ms=10_000) device.apply_settings(session, auto_lock_delay_ms=10_000)
@ -57,7 +56,6 @@ def test_autolock_not_retained(session: Session):
assert client.features.auto_lock_delay_ms > 10_000 assert client.features.auto_lock_delay_ms > 10_000
with client:
client.use_pin_sequence([PIN4, PIN4]) client.use_pin_sequence([PIN4, PIN4])
device.setup( device.setup(
session, session,
@ -71,7 +69,7 @@ def test_autolock_not_retained(session: Session):
time.sleep(10.5) time.sleep(10.5)
session = client.get_session() session = client.get_session()
with session, client: with session.client as client:
# after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked # after sleeping for the pre-wipe autolock amount, Trezor must still be unlocked
session.set_expected_responses([messages.Address]) client.set_expected_responses([messages.Address])
get_test_address(session) get_test_address(session)

View File

@ -33,17 +33,17 @@ pytestmark = pytest.mark.setup_client(pin=PIN4)
@pytest.mark.setup_client(pin=None) @pytest.mark.setup_client(pin=None)
def test_no_protection(session: Session): def test_no_protection(session: Session):
with session: with session.client as client:
session.set_expected_responses([messages.Address]) client.set_expected_responses([messages.Address])
get_test_address(session) get_test_address(session)
def test_correct_pin(session: Session): def test_correct_pin(session: Session):
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
# Expected responses differ between T1 and TT # Expected responses differ between T1 and TT
is_t1 = session.model is models.T1B1 is_t1 = session.model is models.T1B1
session.set_expected_responses( client.set_expected_responses(
[ [
(is_t1, messages.PinMatrixRequest), (is_t1, messages.PinMatrixRequest),
( (
@ -65,10 +65,10 @@ def test_incorrect_pin_t1(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_incorrect_pin_t2(session: Session): def test_incorrect_pin_t2(session: Session):
with session, session.client as client: with session.client as client:
# After first incorrect attempt, TT will not raise an error, but instead ask for another attempt # After first incorrect attempt, TT will not raise an error, but instead ask for another attempt
client.use_pin_sequence([BAD_PIN, PIN4]) client.use_pin_sequence([BAD_PIN, PIN4])
session.set_expected_responses( client.set_expected_responses(
[ [
messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry),
messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry), messages.ButtonRequest(code=messages.ButtonRequestType.PinEntry),
@ -82,7 +82,7 @@ def test_incorrect_pin_t2(session: Session):
def test_exponential_backoff_t1(session: Session): def test_exponential_backoff_t1(session: Session):
for attempt in range(3): for attempt in range(3):
start = time.time() start = time.time()
with session, session.client as client, pytest.raises(PinException): with session.client as client, pytest.raises(PinException):
client.use_pin_sequence([BAD_PIN]) client.use_pin_sequence([BAD_PIN])
get_test_address(session) get_test_address(session)
check_pin_backoff_time(attempt, start) check_pin_backoff_time(attempt, start)

View File

@ -97,7 +97,7 @@ def test_passphrase_reporting(session: Session, passphrase):
"""On TT, passphrase_protection is a private setting, so a locked device should """On TT, passphrase_protection is a private setting, so a locked device should
report passphrase_protection=None. report passphrase_protection=None.
""" """
with session, session.client as client: with session.client as client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
device.apply_settings(session, use_passphrase=passphrase) device.apply_settings(session, use_passphrase=passphrase)
@ -164,7 +164,7 @@ def test_change_pin_t2(client: Client):
_pin_request(client), _pin_request(client),
_pin_request(client), _pin_request(client),
( (
session.client.layout_type is LayoutType.Caesar, client.layout_type is LayoutType.Caesar,
messages.ButtonRequest, messages.ButtonRequest,
), ),
_pin_request(client), _pin_request(client),
@ -238,7 +238,7 @@ def test_wipe_device(client: Client):
session = client.get_session() session = client.get_session()
client.set_expected_responses([messages.ButtonRequest, messages.Success]) client.set_expected_responses([messages.ButtonRequest, messages.Success])
device.wipe(session) device.wipe(session)
client = session.client.get_new_client() client = client.get_new_client()
session = client.get_seedless_session() session = client.get_seedless_session()
with client: with client:
client.set_expected_responses([messages.Features]) client.set_expected_responses([messages.Features])
@ -251,8 +251,8 @@ def test_wipe_device(client: Client):
def test_reset_device(session: Session): def test_reset_device(session: Session):
assert session.features.pin_protection is False assert session.features.pin_protection is False
assert session.features.passphrase_protection is False assert session.features.passphrase_protection is False
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest] [messages.ButtonRequest]
+ [messages.EntropyRequest] + [messages.EntropyRequest]
+ [messages.ButtonRequest] * 24 + [messages.ButtonRequest] * 24
@ -289,8 +289,8 @@ def test_recovery_device(session: Session, uninitialized_session=True):
assert session.features.pin_protection is False assert session.features.pin_protection is False
assert session.features.passphrase_protection is False assert session.features.passphrase_protection is False
session.client.use_mnemonic(MNEMONIC12) session.client.use_mnemonic(MNEMONIC12)
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest] [messages.ButtonRequest]
+ [messages.WordRequest] * 24 + [messages.WordRequest] * 24
+ [messages.Success] # , messages.Features] + [messages.Success] # , messages.Features]
@ -302,7 +302,7 @@ def test_recovery_device(session: Session, uninitialized_session=True):
False, False,
False, False,
"label", "label",
input_callback=session.client.mnemonic_callback, input_callback=client.mnemonic_callback,
) )
with pytest.raises(TrezorFailure): with pytest.raises(TrezorFailure):

View File

@ -34,12 +34,13 @@ pytestmark = pytest.mark.models("core")
@pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6)
def test_repeated_backup(session: Session): def test_repeated_backup(session: Session):
client = session.client
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
assert session.features.recovery_status == messages.RecoveryStatus.Nothing assert session.features.recovery_status == messages.RecoveryStatus.Nothing
# initial device backup # initial device backup
mnemonics = [] mnemonics = []
with session, session.client as client: with client:
IF = InputFlowSlip39BasicBackup(client, False) IF = InputFlowSlip39BasicBackup(client, False)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup(session) device.backup(session)
@ -56,7 +57,7 @@ def test_repeated_backup(session: Session):
device.backup(session) device.backup(session)
# unlock repeated backup by entering 3 of the 5 shares we have got # unlock repeated backup by entering 3 of the 5 shares we have got
with session, session.client as client: with client:
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, mnemonics[:3], unlock_repeated_backup=True client, mnemonics[:3], unlock_repeated_backup=True
) )
@ -69,7 +70,7 @@ def test_repeated_backup(session: Session):
assert session.features.recovery_status == messages.RecoveryStatus.Backup assert session.features.recovery_status == messages.RecoveryStatus.Backup
# we can now perform another backup # we can now perform another backup
with session, session.client as client: with client:
IF = InputFlowSlip39BasicBackup(client, False, repeated=True) IF = InputFlowSlip39BasicBackup(client, False, repeated=True)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup(session) device.backup(session)
@ -85,6 +86,7 @@ def test_repeated_backup(session: Session):
@pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_SINGLE_EXT_20) @pytest.mark.setup_client(mnemonic=MNEMONIC_SLIP39_SINGLE_EXT_20)
def test_repeated_backup_upgrade_single(session: Session): def test_repeated_backup_upgrade_single(session: Session):
client = session.client
assert ( assert (
session.features.backup_availability == messages.BackupAvailability.NotAvailable session.features.backup_availability == messages.BackupAvailability.NotAvailable
) )
@ -92,7 +94,7 @@ def test_repeated_backup_upgrade_single(session: Session):
assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable assert session.features.backup_type == messages.BackupType.Slip39_Single_Extendable
# unlock repeated backup by entering the single share # unlock repeated backup by entering the single share
with session, session.client as client: with client:
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True client, MNEMONIC_SLIP39_SINGLE_EXT_20, unlock_repeated_backup=True
) )
@ -105,7 +107,7 @@ def test_repeated_backup_upgrade_single(session: Session):
assert session.features.recovery_status == messages.RecoveryStatus.Backup assert session.features.recovery_status == messages.RecoveryStatus.Backup
# we can now perform another backup # we can now perform another backup
with session, session.client as client: with client:
IF = InputFlowSlip39BasicBackup(client, False, repeated=True) IF = InputFlowSlip39BasicBackup(client, False, repeated=True)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup(session) device.backup(session)
@ -123,12 +125,13 @@ def test_repeated_backup_upgrade_single(session: Session):
@pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6)
def test_repeated_backup_cancel(session: Session): def test_repeated_backup_cancel(session: Session):
client = session.client
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
assert session.features.recovery_status == messages.RecoveryStatus.Nothing assert session.features.recovery_status == messages.RecoveryStatus.Nothing
# initial device backup # initial device backup
mnemonics = [] mnemonics = []
with session, session.client as client: with client:
IF = InputFlowSlip39BasicBackup(client, False) IF = InputFlowSlip39BasicBackup(client, False)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup(session) device.backup(session)
@ -145,7 +148,7 @@ def test_repeated_backup_cancel(session: Session):
device.backup(session) device.backup(session)
# unlock repeated backup by entering 3 of the 5 shares we have got # unlock repeated backup by entering 3 of the 5 shares we have got
with session, session.client as client: with client:
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, mnemonics[:3], unlock_repeated_backup=True client, mnemonics[:3], unlock_repeated_backup=True
) )
@ -157,7 +160,7 @@ def test_repeated_backup_cancel(session: Session):
) )
assert session.features.recovery_status == messages.RecoveryStatus.Backup assert session.features.recovery_status == messages.RecoveryStatus.Backup
layout = session.client.debug.read_layout() layout = client.debug.read_layout()
assert TR.recovery__unlock_repeated_backup in layout.text_content() assert TR.recovery__unlock_repeated_backup in layout.text_content()
# send a Cancel message # send a Cancel message
@ -178,12 +181,13 @@ def test_repeated_backup_cancel(session: Session):
@pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6) @pytest.mark.setup_client(needs_backup=True, mnemonic=MNEMONIC_SLIP39_BASIC_20_3of6)
def test_repeated_backup_send_disallowed_message(session: Session): def test_repeated_backup_send_disallowed_message(session: Session):
client = session.client
assert session.features.backup_availability == messages.BackupAvailability.Required assert session.features.backup_availability == messages.BackupAvailability.Required
assert session.features.recovery_status == messages.RecoveryStatus.Nothing assert session.features.recovery_status == messages.RecoveryStatus.Nothing
# initial device backup # initial device backup
mnemonics = [] mnemonics = []
with session, session.client as client: with client:
IF = InputFlowSlip39BasicBackup(client, False) IF = InputFlowSlip39BasicBackup(client, False)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
device.backup(session) device.backup(session)
@ -200,7 +204,7 @@ def test_repeated_backup_send_disallowed_message(session: Session):
device.backup(session) device.backup(session)
# unlock repeated backup by entering 3 of the 5 shares we have got # unlock repeated backup by entering 3 of the 5 shares we have got
with session, session.client as client: with client:
IF = InputFlowSlip39BasicRecoveryDryRun( IF = InputFlowSlip39BasicRecoveryDryRun(
client, mnemonics[:3], unlock_repeated_backup=True client, mnemonics[:3], unlock_repeated_backup=True
) )
@ -212,7 +216,7 @@ def test_repeated_backup_send_disallowed_message(session: Session):
) )
assert session.features.recovery_status == messages.RecoveryStatus.Backup assert session.features.recovery_status == messages.RecoveryStatus.Backup
layout = session.client.debug.read_layout() layout = client.debug.read_layout()
assert TR.recovery__unlock_repeated_backup in layout.text_content() assert TR.recovery__unlock_repeated_backup in layout.text_content()
# send a GetAddress message # send a GetAddress message
@ -233,8 +237,7 @@ def test_repeated_backup_send_disallowed_message(session: Session):
# we are still on the confirmation screen! # we are still on the confirmation screen!
assert ( assert (
TR.recovery__unlock_repeated_backup TR.recovery__unlock_repeated_backup in client.debug.read_layout().text_content()
in session.client.debug.read_layout().text_content()
) )
with pytest.raises(exceptions.Cancelled): with pytest.raises(exceptions.Cancelled):
session.call(messages.Cancel()) session.call(messages.Cancel())

View File

@ -36,7 +36,8 @@ def test_sd_format(session: Session):
@pytest.mark.sd_card(formatted=False) @pytest.mark.sd_card(formatted=False)
def test_sd_no_format(session: Session): def test_sd_no_format(session: Session):
debug = session.client.debug client = session.client
debug = client.debug
def input_flow(): def input_flow():
yield # enable SD protection? yield # enable SD protection?
@ -45,7 +46,7 @@ def test_sd_no_format(session: Session):
yield # format SD card yield # format SD card
debug.press_no() debug.press_no()
with session, session.client as client, pytest.raises(TrezorFailure) as e: with client, pytest.raises(TrezorFailure) as e:
client.set_input_flow(input_flow) client.set_input_flow(input_flow)
device.sd_protect(session, Op.ENABLE) device.sd_protect(session, Op.ENABLE)
@ -55,7 +56,8 @@ def test_sd_no_format(session: Session):
@pytest.mark.sd_card @pytest.mark.sd_card
@pytest.mark.setup_client(pin=PIN) @pytest.mark.setup_client(pin=PIN)
def test_sd_protect_unlock(session: Session): def test_sd_protect_unlock(session: Session):
debug = session.client.debug client = session.client
debug = client.debug
layout = debug.read_layout layout = debug.read_layout
def input_flow_enable_sd_protect(): def input_flow_enable_sd_protect():
@ -76,7 +78,7 @@ def test_sd_protect_unlock(session: Session):
assert TR.sd_card__enabled in layout().text_content() assert TR.sd_card__enabled in layout().text_content()
debug.press_yes() debug.press_yes()
with session, session.client as client: with client:
client.watch_layout() client.watch_layout()
client.set_input_flow(input_flow_enable_sd_protect) client.set_input_flow(input_flow_enable_sd_protect)
device.sd_protect(session, Op.ENABLE) device.sd_protect(session, Op.ENABLE)
@ -102,7 +104,7 @@ def test_sd_protect_unlock(session: Session):
assert TR.pin__changed in layout().text_content() assert TR.pin__changed in layout().text_content()
debug.press_yes() debug.press_yes()
with session, session.client as client: with client:
client.watch_layout() client.watch_layout()
client.set_input_flow(input_flow_change_pin) client.set_input_flow(input_flow_change_pin)
device.change_pin(session) device.change_pin(session)
@ -125,7 +127,7 @@ def test_sd_protect_unlock(session: Session):
) )
debug.press_no() # close debug.press_no() # close
with session, session.client as client, pytest.raises(TrezorFailure) as e: with client, pytest.raises(TrezorFailure) as e:
client.watch_layout() client.watch_layout()
client.set_input_flow(input_flow_change_pin_format) client.set_input_flow(input_flow_change_pin_format)
device.change_pin(session) device.change_pin(session)

View File

@ -71,9 +71,9 @@ def test_clear_session(client: Client):
assert _get_public_node(session, ADDRESS_N).xpub == XPUB assert _get_public_node(session, ADDRESS_N).xpub == XPUB
session.resume() session.resume()
with session: with client:
# pin and passphrase are cached # pin and passphrase are cached
session.set_expected_responses(cached_responses) client.set_expected_responses(cached_responses)
assert _get_public_node(session, ADDRESS_N).xpub == XPUB assert _get_public_node(session, ADDRESS_N).xpub == XPUB
session.lock() session.lock()
@ -87,9 +87,9 @@ def test_clear_session(client: Client):
assert _get_public_node(session, ADDRESS_N).xpub == XPUB assert _get_public_node(session, ADDRESS_N).xpub == XPUB
session.resume() session.resume()
with session: with client:
# pin and passphrase are cached # pin and passphrase are cached
session.set_expected_responses(cached_responses) client.set_expected_responses(cached_responses)
assert _get_public_node(session, ADDRESS_N).xpub == XPUB assert _get_public_node(session, ADDRESS_N).xpub == XPUB
@ -100,8 +100,8 @@ def test_end_session(client: Client):
assert session.id is not None assert session.id is not None
# get_address will succeed # get_address will succeed
with session: with client:
session.set_expected_responses([messages.Address]) client.set_expected_responses([messages.Address])
get_test_address(session) get_test_address(session)
session.end() session.end()
@ -113,13 +113,13 @@ def test_end_session(client: Client):
session = client.get_session() session = client.get_session()
assert session.id is not None assert session.id is not None
with session: with client:
session.set_expected_responses([messages.Address]) client.set_expected_responses([messages.Address])
get_test_address(session) get_test_address(session)
with session as session: with client:
# end_session should succeed on empty session too # end_session should succeed on empty session too
session.set_expected_responses([messages.Success] * 2) client.set_expected_responses([messages.Success] * 2)
session.end() session.end()
session.end() session.end()
@ -162,8 +162,8 @@ def test_end_session_only_current(client: Client):
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
def test_session_recycling(client: Client): def test_session_recycling(client: Client):
session = client.get_session(passphrase="TREZOR") session = client.get_session(passphrase="TREZOR")
with session: with client:
session.set_expected_responses([messages.Address]) client.set_expected_responses([messages.Address])
address = get_test_address(session) address = get_test_address(session)
# create and close 100 sessions - more than the session limit # create and close 100 sessions - more than the session limit
@ -172,7 +172,7 @@ def test_session_recycling(client: Client):
session_x.end() session_x.end()
# it should still be possible to resume the original session # it should still be possible to resume the original session
with client, session: with client:
# passphrase should still be cached # passphrase should still be cached
expected_responses = [messages.Address] * 3 expected_responses = [messages.Address] * 3
if client.protocol_version == ProtocolVersion.V1: if client.protocol_version == ProtocolVersion.V1:

View File

@ -65,8 +65,8 @@ def _get_xpub(
else: else:
expected_responses = [messages.PublicKey] expected_responses = [messages.PublicKey]
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
result = session.call_raw(XPUB_REQUEST) result = session.call_raw(XPUB_REQUEST)
if passphrase is not None: if passphrase is not None:
result = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) result = session.call_raw(messages.PassphraseAck(passphrase=passphrase))
@ -430,7 +430,7 @@ def test_passphrase_length(client: Client):
def test_hide_passphrase_from_host(client: Client): def test_hide_passphrase_from_host(client: Client):
# Without safety checks, turning it on fails # Without safety checks, turning it on fails
session = client.get_seedless_session() session = client.get_seedless_session()
with pytest.raises(TrezorFailure, match="Safety checks are strict"), client: with pytest.raises(TrezorFailure, match="Safety checks are strict"):
device.apply_settings(session, hide_passphrase_from_host=True) device.apply_settings(session, hide_passphrase_from_host=True)
device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily) device.apply_settings(session, safety_checks=SafetyCheckLevel.PromptTemporarily)
@ -440,7 +440,7 @@ def test_hide_passphrase_from_host(client: Client):
passphrase = "abc" passphrase = "abc"
session = _get_session(client) session = _get_session(client)
with session: with client:
def input_flow(): def input_flow():
yield yield
@ -455,9 +455,9 @@ def test_hide_passphrase_from_host(client: Client):
else: else:
raise KeyError raise KeyError
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow) client.set_input_flow(input_flow)
session.set_expected_responses( client.set_expected_responses(
[ [
messages.PassphraseRequest, messages.PassphraseRequest,
messages.ButtonRequest, messages.ButtonRequest,
@ -476,7 +476,7 @@ def test_hide_passphrase_from_host(client: Client):
# Starting new session, otherwise the passphrase would be cached # Starting new session, otherwise the passphrase would be cached
session = _get_session(client) session = _get_session(client)
with client, session: with client:
def input_flow(): def input_flow():
yield yield
@ -491,9 +491,9 @@ def test_hide_passphrase_from_host(client: Client):
assert passphrase in client.debug.read_layout().text_content() assert passphrase in client.debug.read_layout().text_content()
client.debug.press_yes() client.debug.press_yes()
client.watch_layout() session.client.watch_layout()
client.set_input_flow(input_flow) client.set_input_flow(input_flow)
session.set_expected_responses( client.set_expected_responses(
[ [
messages.PassphraseRequest, messages.PassphraseRequest,
messages.ButtonRequest, messages.ButtonRequest,

View File

@ -45,7 +45,7 @@ def test_tezos_get_address_chunkify_details(
session: Session, path: str, expected_address: str session: Session, path: str, expected_address: str
): ):
with session.client as client: with session.client as client:
IF = InputFlowShowAddressQRCode(client) IF = InputFlowShowAddressQRCode(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
address = get_address( address = get_address(
session, parse_path(path), show_display=True, chunkify=True session, parse_path(path), show_display=True, chunkify=True

View File

@ -33,7 +33,7 @@ pytestmark = [
def test_tezos_sign_tx_proposal(session: Session): def test_tezos_sign_tx_proposal(session: Session):
with session: with session.client:
resp = tezos.sign_tx( resp = tezos.sign_tx(
session, session,
TEZOS_PATH_10, TEZOS_PATH_10,
@ -64,7 +64,7 @@ def test_tezos_sign_tx_proposal(session: Session):
def test_tezos_sign_tx_multiple_proposals(session: Session): def test_tezos_sign_tx_multiple_proposals(session: Session):
with session: with session.client:
resp = tezos.sign_tx( resp = tezos.sign_tx(
session, session,
TEZOS_PATH_10, TEZOS_PATH_10,

View File

@ -31,8 +31,8 @@ RK_CAPACITY = 100
@pytest.mark.altcoin @pytest.mark.altcoin
@pytest.mark.setup_client(mnemonic=MNEMONIC12) @pytest.mark.setup_client(mnemonic=MNEMONIC12)
def test_add_remove(session: Session): def test_add_remove(session: Session):
with session, session.client as client: with session.client as client:
IF = InputFlowFidoConfirm(client) IF = InputFlowFidoConfirm(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
# Remove index 0 should fail. # Remove index 0 should fail.

View File

@ -95,8 +95,8 @@ def test_spend_v4_input(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -143,8 +143,8 @@ def test_send_to_multisig(session: Session):
script_type=messages.OutputScriptType.PAYTOSCRIPTHASH, script_type=messages.OutputScriptType.PAYTOSCRIPTHASH,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -190,8 +190,8 @@ def test_spend_v5_input(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -243,8 +243,8 @@ def test_one_two(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -301,8 +301,8 @@ def test_unified_address(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_output(0), request_output(0),
@ -365,8 +365,8 @@ def test_external_presigned(session: Session):
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
) )
with session: with session.client as client:
session.set_expected_responses( client.set_expected_responses(
[ [
request_input(0), request_input(0),
request_input(1), request_input(1),
@ -489,8 +489,8 @@ def test_spend_multisig(session: Session):
request_finished(), request_finished(),
] ]
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
signatures1, _ = btc.sign_tx( signatures1, _ = btc.sign_tx(
session, session,
"Zcash Testnet", "Zcash Testnet",
@ -529,8 +529,8 @@ def test_spend_multisig(session: Session):
multisig=multisig, multisig=multisig,
) )
with session: with session.client as client:
session.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
signatures2, serialized_tx = btc.sign_tx( signatures2, serialized_tx = btc.sign_tx(
session, session,
"Zcash Testnet", "Zcash Testnet",

View File

@ -50,16 +50,18 @@ class InputFlowBase:
# There could be one common input flow for all models # There could be one common input flow for all models
if hasattr(self, "input_flow_common"): if hasattr(self, "input_flow_common"):
return getattr(self, "input_flow_common") flow = getattr(self, "input_flow_common")
elif self.client.layout_type is LayoutType.Bolt: elif self.client.layout_type is LayoutType.Bolt:
return self.input_flow_bolt flow = self.input_flow_bolt
elif self.client.layout_type is LayoutType.Caesar: elif self.client.layout_type is LayoutType.Caesar:
return self.input_flow_caesar flow = self.input_flow_caesar
elif self.client.layout_type is LayoutType.Delizia: elif self.client.layout_type is LayoutType.Delizia:
return self.input_flow_delizia flow = self.input_flow_delizia
else: else:
raise ValueError("Unknown model") raise ValueError("Unknown model")
return flow
def input_flow_bolt(self) -> BRGeneratorType: def input_flow_bolt(self) -> BRGeneratorType:
"""Special for TT""" """Special for TT"""
raise NotImplementedError raise NotImplementedError
@ -371,7 +373,7 @@ class InputFlowSignMessageInfo(InputFlowBase):
self.debug.click(self.client.debug.screen_buttons.vertical_menu_items()[1]) self.debug.click(self.client.debug.screen_buttons.vertical_menu_items()[1])
# address mismatch? yes! # address mismatch? yes!
self.debug.swipe_up() self.debug.swipe_up()
yield yield # ?
class InputFlowShowAddressQRCode(InputFlowBase): class InputFlowShowAddressQRCode(InputFlowBase):

View File

@ -11,33 +11,37 @@ WIPE_CODE = "9876"
def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None: def setup_device_legacy(client: Client, pin: str, wipe_code: str) -> None:
device.wipe(client.get_seedless_session()) session = client.get_seedless_session()
device.wipe(session)
client = client.get_new_client() client = client.get_new_client()
session = client.get_seedless_session()
debuglink.load_device( debuglink.load_device(
client.get_seedless_session(), session,
MNEMONIC12, MNEMONIC12,
pin, pin,
passphrase_protection=False, passphrase_protection=False,
label="WIPECODE", label="WIPECODE",
) )
with client: with session.client as client:
client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE]) client.use_pin_sequence([PIN, WIPE_CODE, WIPE_CODE])
device.change_wipe_code(client.get_seedless_session()) device.change_wipe_code(client.get_seedless_session())
def setup_device_core(client: Client, pin: str, wipe_code: str) -> None: def setup_device_core(client: Client, pin: str, wipe_code: str) -> None:
device.wipe(client.get_seedless_session()) session = client.get_seedless_session()
device.wipe(session)
client = client.get_new_client() client = client.get_new_client()
session = client.get_seedless_session()
debuglink.load_device( debuglink.load_device(
client.get_seedless_session(), session,
MNEMONIC12, MNEMONIC12,
pin, pin,
passphrase_protection=False, passphrase_protection=False,
label="WIPECODE", label="WIPECODE",
) )
with client: with session.client as client:
client.use_pin_sequence([pin, wipe_code, wipe_code]) client.use_pin_sequence([pin, wipe_code, wipe_code])
device.change_wipe_code(client.get_seedless_session()) device.change_wipe_code(client.get_seedless_session())

View File

@ -69,7 +69,7 @@ def set_language(session: Session, lang: str, *, force: bool = True):
language_data = b"" language_data = b""
else: else:
language_data = build_and_sign_blob(lang, session) language_data = build_and_sign_blob(lang, session)
with session: with session.client:
if not session.features.language.startswith(lang) or force: if not session.features.language.startswith(lang) or force:
device.change_language(session, language_data) # type: ignore device.change_language(session, language_data) # type: ignore
_CURRENT_TRANSLATION.TR = TRANSLATIONS[lang] _CURRENT_TRANSLATION.TR = TRANSLATIONS[lang]