mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-10 07:20:56 +00:00
test: update test framework
[no changelog]
This commit is contained in:
parent
c30769bac5
commit
92d80a3653
@ -11,6 +11,7 @@ multisig
|
|||||||
nem
|
nem
|
||||||
ontology
|
ontology
|
||||||
peercoin
|
peercoin
|
||||||
|
protocol
|
||||||
ripple
|
ripple
|
||||||
sd_card
|
sd_card
|
||||||
solana
|
solana
|
||||||
|
@ -34,8 +34,8 @@ if TYPE_CHECKING:
|
|||||||
from _pytest.mark.structures import MarkDecorator
|
from _pytest.mark.structures import MarkDecorator
|
||||||
|
|
||||||
from trezorlib.debuglink import DebugLink
|
from trezorlib.debuglink import DebugLink
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
|
||||||
from trezorlib.messages import ButtonRequest
|
from trezorlib.messages import ButtonRequest
|
||||||
|
from trezorlib.transport.session import Session
|
||||||
|
|
||||||
PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")]
|
PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")]
|
||||||
|
|
||||||
@ -338,10 +338,10 @@ def check_pin_backoff_time(attempts: int, start: float) -> None:
|
|||||||
assert got >= expected
|
assert got >= expected
|
||||||
|
|
||||||
|
|
||||||
def get_test_address(client: "Client") -> str:
|
def get_test_address(session: "Session") -> str:
|
||||||
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
|
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
|
||||||
protected call, or to identify the root secret (seed+passphrase)"""
|
protected call, or to identify the root secret (seed+passphrase)"""
|
||||||
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
|
return btc.get_address(session, "Testnet", TEST_ADDRESS_N)
|
||||||
|
|
||||||
|
|
||||||
def compact_size(n: int) -> bytes:
|
def compact_size(n: int) -> bytes:
|
||||||
@ -380,5 +380,5 @@ def swipe_till_the_end(debug: "DebugLink", br: messages.ButtonRequest) -> None:
|
|||||||
debug.swipe_up()
|
debug.swipe_up()
|
||||||
|
|
||||||
|
|
||||||
def is_core(client: "Client") -> bool:
|
def is_core(session: "Session") -> bool:
|
||||||
return client.model is not models.T1B1
|
return session.model is not models.T1B1
|
||||||
|
@ -20,17 +20,22 @@ import os
|
|||||||
import typing as t
|
import typing as t
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
|
import cryptography
|
||||||
import pytest
|
import pytest
|
||||||
import xdist
|
import xdist
|
||||||
from _pytest.python import IdMaker
|
from _pytest.python import IdMaker
|
||||||
from _pytest.reports import TestReport
|
from _pytest.reports import TestReport
|
||||||
|
|
||||||
from trezorlib import debuglink, log, models
|
from trezorlib import debuglink, log, models
|
||||||
|
from trezorlib.client import ProtocolVersion
|
||||||
|
from trezorlib.debuglink import SessionDebugWrapper
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
from trezorlib.device import apply_settings
|
from trezorlib.device import apply_settings
|
||||||
from trezorlib.device import wipe as wipe_device
|
from trezorlib.device import wipe as wipe_device
|
||||||
from trezorlib.transport import enumerate_devices, get_transport
|
from trezorlib.transport import enumerate_devices, get_transport
|
||||||
|
from trezorlib.transport.thp.protocol_v1 import ProtocolV1
|
||||||
|
|
||||||
# register rewrites before importing from local package
|
# register rewrites before importing from local package
|
||||||
# so that we see details of failed asserts from this module
|
# so that we see details of failed asserts from this module
|
||||||
@ -135,6 +140,10 @@ def emulator(request: pytest.FixtureRequest) -> t.Generator["Emulator", None, No
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def _raw_client(request: pytest.FixtureRequest) -> Client:
|
def _raw_client(request: pytest.FixtureRequest) -> Client:
|
||||||
|
return _get_raw_client(request)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_raw_client(request: pytest.FixtureRequest) -> Client:
|
||||||
# In case tests run in parallel, each process has its own emulator/client.
|
# In case tests run in parallel, each process has its own emulator/client.
|
||||||
# Requesting the emulator fixture only if relevant.
|
# Requesting the emulator fixture only if relevant.
|
||||||
if request.session.config.getoption("control_emulators"):
|
if request.session.config.getoption("control_emulators"):
|
||||||
@ -273,6 +282,29 @@ def client(
|
|||||||
if _raw_client.model not in models_filter:
|
if _raw_client.model not in models_filter:
|
||||||
pytest.skip(f"Skipping test for model {_raw_client.model.internal_name}")
|
pytest.skip(f"Skipping test for model {_raw_client.model.internal_name}")
|
||||||
|
|
||||||
|
protocol_marker: Mark | None = request.node.get_closest_marker("protocol")
|
||||||
|
if protocol_marker:
|
||||||
|
args = protocol_marker.args
|
||||||
|
protocol_version = _raw_client.protocol_version
|
||||||
|
|
||||||
|
if (
|
||||||
|
protocol_version == ProtocolVersion.PROTOCOL_V1
|
||||||
|
and "protocol_v1" not in args
|
||||||
|
):
|
||||||
|
pytest.xfail(
|
||||||
|
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
protocol_version == ProtocolVersion.PROTOCOL_V2
|
||||||
|
and "protocol_v2" not in args
|
||||||
|
):
|
||||||
|
pytest.xfail(
|
||||||
|
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
||||||
|
)
|
||||||
|
|
||||||
|
if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||||
|
pass
|
||||||
sd_marker = request.node.get_closest_marker("sd_card")
|
sd_marker = request.node.get_closest_marker("sd_card")
|
||||||
if sd_marker and not _raw_client.features.sd_card_present:
|
if sd_marker and not _raw_client.features.sd_card_present:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -283,11 +315,12 @@ def client(
|
|||||||
|
|
||||||
test_ui = request.config.getoption("ui")
|
test_ui = request.config.getoption("ui")
|
||||||
|
|
||||||
_raw_client.reset_debug_features()
|
_raw_client.reset_debug_features(new_management_session=True)
|
||||||
_raw_client.open()
|
_raw_client.open()
|
||||||
|
if isinstance(_raw_client.protocol, ProtocolV1):
|
||||||
try:
|
try:
|
||||||
_raw_client.sync_responses()
|
_raw_client.sync_responses()
|
||||||
_raw_client.init_device()
|
# TODO _raw_client.init_device()
|
||||||
except Exception:
|
except Exception:
|
||||||
request.session.shouldstop = "Failed to communicate with Trezor"
|
request.session.shouldstop = "Failed to communicate with Trezor"
|
||||||
pytest.fail("Failed to communicate with Trezor")
|
pytest.fail("Failed to communicate with Trezor")
|
||||||
@ -303,13 +336,34 @@ def client(
|
|||||||
should_format = sd_marker.kwargs.get("formatted", True)
|
should_format = sd_marker.kwargs.get("formatted", True)
|
||||||
_raw_client.debug.erase_sd_card(format=should_format)
|
_raw_client.debug.erase_sd_card(format=should_format)
|
||||||
|
|
||||||
wipe_device(_raw_client)
|
while True:
|
||||||
|
try:
|
||||||
|
session = _raw_client.get_management_session()
|
||||||
|
wipe_device(session)
|
||||||
|
sleep(1.5) # Makes tests more stable (wait for wipe to finish)
|
||||||
|
break
|
||||||
|
except cryptography.exceptions.InvalidTag:
|
||||||
|
# Get a new client
|
||||||
|
_raw_client = _get_raw_client(request)
|
||||||
|
|
||||||
|
from trezorlib.transport.thp.channel_database import get_channel_db
|
||||||
|
|
||||||
|
get_channel_db().clear_stored_channels()
|
||||||
|
_raw_client.protocol = None
|
||||||
|
_raw_client.__init__(
|
||||||
|
transport=_raw_client.transport,
|
||||||
|
auto_interact=_raw_client.debug.allow_interactions,
|
||||||
|
)
|
||||||
|
if not _raw_client.features.bootloader_mode:
|
||||||
|
_raw_client.refresh_features()
|
||||||
|
|
||||||
# Load language again, as it got erased in wipe
|
# Load language again, as it got erased in wipe
|
||||||
if _raw_client.model is not models.T1B1:
|
if _raw_client.model is not models.T1B1:
|
||||||
lang = request.session.config.getoption("lang") or "en"
|
lang = request.session.config.getoption("lang") or "en"
|
||||||
assert isinstance(lang, str)
|
assert isinstance(lang, str)
|
||||||
translations.set_language(_raw_client, lang)
|
translations.set_language(
|
||||||
|
SessionDebugWrapper(_raw_client.get_management_session()), lang
|
||||||
|
)
|
||||||
|
|
||||||
setup_params = dict(
|
setup_params = dict(
|
||||||
uninitialized=False,
|
uninitialized=False,
|
||||||
@ -327,10 +381,10 @@ def client(
|
|||||||
use_passphrase = setup_params["passphrase"] is True or isinstance(
|
use_passphrase = setup_params["passphrase"] is True or isinstance(
|
||||||
setup_params["passphrase"], str
|
setup_params["passphrase"], str
|
||||||
)
|
)
|
||||||
|
|
||||||
if not setup_params["uninitialized"]:
|
if not setup_params["uninitialized"]:
|
||||||
|
session = _raw_client.get_management_session(new_session=True)
|
||||||
debuglink.load_device(
|
debuglink.load_device(
|
||||||
_raw_client,
|
session,
|
||||||
mnemonic=setup_params["mnemonic"], # type: ignore
|
mnemonic=setup_params["mnemonic"], # type: ignore
|
||||||
pin=setup_params["pin"], # type: ignore
|
pin=setup_params["pin"], # type: ignore
|
||||||
passphrase_protection=use_passphrase,
|
passphrase_protection=use_passphrase,
|
||||||
@ -338,14 +392,16 @@ def client(
|
|||||||
needs_backup=setup_params["needs_backup"], # type: ignore
|
needs_backup=setup_params["needs_backup"], # type: ignore
|
||||||
no_backup=setup_params["no_backup"], # type: ignore
|
no_backup=setup_params["no_backup"], # type: ignore
|
||||||
)
|
)
|
||||||
|
if setup_params["pin"] is not None:
|
||||||
|
_raw_client._has_setup_pin = True
|
||||||
|
|
||||||
if request.node.get_closest_marker("experimental"):
|
if request.node.get_closest_marker("experimental"):
|
||||||
apply_settings(_raw_client, experimental_features=True)
|
apply_settings(session, experimental_features=True)
|
||||||
|
|
||||||
if use_passphrase and isinstance(setup_params["passphrase"], str):
|
if use_passphrase and isinstance(setup_params["passphrase"], str):
|
||||||
_raw_client.use_passphrase(setup_params["passphrase"])
|
_raw_client.use_passphrase(setup_params["passphrase"])
|
||||||
|
|
||||||
_raw_client.clear_session()
|
# TODO _raw_client.clear_session()
|
||||||
|
|
||||||
with ui_tests.screen_recording(_raw_client, request):
|
with ui_tests.screen_recording(_raw_client, request):
|
||||||
yield _raw_client
|
yield _raw_client
|
||||||
@ -353,6 +409,29 @@ def client(
|
|||||||
_raw_client.close()
|
_raw_client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
|
||||||
|
def session(
|
||||||
|
request: pytest.FixtureRequest, client: Client
|
||||||
|
) -> t.Generator[SessionDebugWrapper, None, None]:
|
||||||
|
if bool(request.node.get_closest_marker("uninitialized_session")):
|
||||||
|
session = client.get_management_session()
|
||||||
|
else:
|
||||||
|
derive_cardano = bool(request.node.get_closest_marker("cardano"))
|
||||||
|
passphrase = client.passphrase or ""
|
||||||
|
session = client.get_session(
|
||||||
|
derive_cardano=derive_cardano, passphrase=passphrase
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
wrapped_session = SessionDebugWrapper(session)
|
||||||
|
if client._has_setup_pin:
|
||||||
|
wrapped_session.lock()
|
||||||
|
yield wrapped_session
|
||||||
|
finally:
|
||||||
|
pass
|
||||||
|
# TODO
|
||||||
|
# session.end()
|
||||||
|
|
||||||
|
|
||||||
def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool:
|
def _is_main_runner(session_or_request: pytest.Session | pytest.FixtureRequest) -> bool:
|
||||||
"""Return True if the current process is the main test runner.
|
"""Return True if the current process is the main test runner.
|
||||||
|
|
||||||
@ -463,6 +542,10 @@ def pytest_configure(config: "Config") -> None:
|
|||||||
"markers",
|
"markers",
|
||||||
'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance',
|
'setup_client(mnemonic="all all all...", pin=None, passphrase=False, uninitialized=False): configure the client instance',
|
||||||
)
|
)
|
||||||
|
config.addinivalue_line(
|
||||||
|
"markers",
|
||||||
|
"uninitialized_session: use uninitialized session instance",
|
||||||
|
)
|
||||||
with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f:
|
with open(os.path.join(os.path.dirname(__file__), "REGISTERED_MARKERS")) as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
config.addinivalue_line("markers", line.strip())
|
config.addinivalue_line("markers", line.strip())
|
||||||
|
@ -48,7 +48,9 @@ class BackgroundDeviceHandler:
|
|||||||
self.client.watch_layout(True)
|
self.client.watch_layout(True)
|
||||||
self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT
|
self.client.debug.input_wait_type = DebugWaitType.CURRENT_LAYOUT
|
||||||
|
|
||||||
def run(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
|
def run_with_session(
|
||||||
|
self, function: Callable[..., Any], *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
"""Runs some function that interacts with a device.
|
"""Runs some function that interacts with a device.
|
||||||
|
|
||||||
Makes sure the UI is updated before returning.
|
Makes sure the UI is updated before returning.
|
||||||
@ -58,15 +60,30 @@ class BackgroundDeviceHandler:
|
|||||||
|
|
||||||
# wait for the first UI change triggered by the task running in the background
|
# wait for the first UI change triggered by the task running in the background
|
||||||
with self.debuglink().wait_for_layout_change():
|
with self.debuglink().wait_for_layout_change():
|
||||||
self.task = self._pool.submit(function, self.client, *args, **kwargs)
|
session = self.client.get_session()
|
||||||
|
self.task = self._pool.submit(function, session, *args, **kwargs)
|
||||||
|
|
||||||
|
def run_with_provided_session(
|
||||||
|
self, session, function: Callable[..., Any], *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Runs some function that interacts with a device.
|
||||||
|
|
||||||
|
Makes sure the UI is updated before returning.
|
||||||
|
"""
|
||||||
|
if self.task is not None:
|
||||||
|
raise RuntimeError("Wait for previous task first")
|
||||||
|
|
||||||
|
# wait for the first UI change triggered by the task running in the background
|
||||||
|
with self.debuglink().wait_for_layout_change():
|
||||||
|
self.task = self._pool.submit(function, session, *args, **kwargs)
|
||||||
|
|
||||||
def kill_task(self) -> None:
|
def kill_task(self) -> None:
|
||||||
if self.task is not None:
|
if self.task is not None:
|
||||||
# Force close the client, which should raise an exception in a client
|
# Force close the client, which should raise an exception in a client
|
||||||
# waiting on IO. Does not work over Bridge, because bridge doesn't have
|
# waiting on IO. Does not work over Bridge, because bridge doesn't have
|
||||||
# a close() method.
|
# a close() method.
|
||||||
while self.client.session_counter > 0:
|
# while self.client.session_counter > 0:
|
||||||
self.client.close()
|
# self.client.close()
|
||||||
try:
|
try:
|
||||||
self.task.result(timeout=1)
|
self.task.result(timeout=1)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -90,7 +107,7 @@ class BackgroundDeviceHandler:
|
|||||||
def features(self) -> "Features":
|
def features(self) -> "Features":
|
||||||
if self.task is not None:
|
if self.task is not None:
|
||||||
raise RuntimeError("Cannot query features while task is running")
|
raise RuntimeError("Cannot query features while task is running")
|
||||||
self.client.init_device()
|
self.client.refresh_features()
|
||||||
return self.client.features
|
return self.client.features
|
||||||
|
|
||||||
def debuglink(self) -> "DebugLink":
|
def debuglink(self) -> "DebugLink":
|
||||||
|
@ -16,6 +16,7 @@ from typing import Callable, Generator
|
|||||||
|
|
||||||
from trezorlib import messages
|
from trezorlib import messages
|
||||||
from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType
|
from trezorlib.debuglink import DebugLink, LayoutContent, LayoutType
|
||||||
|
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
from trezorlib.debuglink import multipage_content
|
from trezorlib.debuglink import multipage_content
|
||||||
|
|
||||||
@ -129,13 +130,15 @@ class InputFlowNewCodeMismatch(InputFlowBase):
|
|||||||
|
|
||||||
|
|
||||||
class InputFlowCodeChangeFail(InputFlowBase):
|
class InputFlowCodeChangeFail(InputFlowBase):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, client: Client, current_pin: str, new_pin_1: str, new_pin_2: str
|
self, session: Session, current_pin: str, new_pin_1: str, new_pin_2: str
|
||||||
):
|
):
|
||||||
super().__init__(client)
|
super().__init__(session.client)
|
||||||
self.current_pin = current_pin
|
self.current_pin = current_pin
|
||||||
self.new_pin_1 = new_pin_1
|
self.new_pin_1 = new_pin_1
|
||||||
self.new_pin_2 = new_pin_2
|
self.new_pin_2 = new_pin_2
|
||||||
|
self.session = session
|
||||||
|
|
||||||
def input_flow_common(self) -> BRGeneratorType:
|
def input_flow_common(self) -> BRGeneratorType:
|
||||||
yield # do you want to change pin?
|
yield # do you want to change pin?
|
||||||
@ -150,7 +153,7 @@ class InputFlowCodeChangeFail(InputFlowBase):
|
|||||||
|
|
||||||
# failed retry
|
# failed retry
|
||||||
yield # enter current pin again
|
yield # enter current pin again
|
||||||
self.client.cancel()
|
self.session.cancel()
|
||||||
|
|
||||||
|
|
||||||
class InputFlowWrongPIN(InputFlowBase):
|
class InputFlowWrongPIN(InputFlowBase):
|
||||||
@ -1880,9 +1883,11 @@ class InputFlowBip39RecoveryDryRun(InputFlowBase):
|
|||||||
|
|
||||||
|
|
||||||
class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
|
class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
|
||||||
def __init__(self, client: Client):
|
|
||||||
super().__init__(client)
|
def __init__(self, session: Session):
|
||||||
|
super().__init__(session.client)
|
||||||
self.invalid_mnemonic = ["stick"] * 12
|
self.invalid_mnemonic = ["stick"] * 12
|
||||||
|
self.session = session
|
||||||
|
|
||||||
def input_flow_common(self) -> BRGeneratorType:
|
def input_flow_common(self) -> BRGeneratorType:
|
||||||
yield from self.REC.confirm_dry_run()
|
yield from self.REC.confirm_dry_run()
|
||||||
@ -1891,7 +1896,7 @@ class InputFlowBip39RecoveryDryRunInvalid(InputFlowBase):
|
|||||||
yield from self.REC.warning_invalid_recovery_seed()
|
yield from self.REC.warning_invalid_recovery_seed()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
self.client.cancel()
|
self.session.cancel()
|
||||||
|
|
||||||
|
|
||||||
class InputFlowBip39Recovery(InputFlowBase):
|
class InputFlowBip39Recovery(InputFlowBase):
|
||||||
@ -1974,15 +1979,17 @@ class InputFlowSlip39AdvancedRecoveryNoAbort(InputFlowBase):
|
|||||||
|
|
||||||
|
|
||||||
class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
|
class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: Client,
|
session: Session,
|
||||||
first_share: list[str],
|
first_share: list[str],
|
||||||
second_share: list[str],
|
second_share: list[str],
|
||||||
):
|
):
|
||||||
super().__init__(client)
|
super().__init__(session.client)
|
||||||
self.first_share = first_share
|
self.first_share = first_share
|
||||||
self.second_share = second_share
|
self.second_share = second_share
|
||||||
|
self.session = session
|
||||||
|
|
||||||
def input_flow_common(self) -> BRGeneratorType:
|
def input_flow_common(self) -> BRGeneratorType:
|
||||||
yield from self.REC.confirm_recovery()
|
yield from self.REC.confirm_recovery()
|
||||||
@ -1994,19 +2001,21 @@ class InputFlowSlip39AdvancedRecoveryThresholdReached(InputFlowBase):
|
|||||||
yield from self.REC.warning_group_threshold_reached()
|
yield from self.REC.warning_group_threshold_reached()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
self.client.cancel()
|
self.session.cancel()
|
||||||
|
|
||||||
|
|
||||||
class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
|
class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: Client,
|
session: Session,
|
||||||
first_share: list[str],
|
first_share: list[str],
|
||||||
second_share: list[str],
|
second_share: list[str],
|
||||||
):
|
):
|
||||||
super().__init__(client)
|
super().__init__(session.client)
|
||||||
self.first_share = first_share
|
self.first_share = first_share
|
||||||
self.second_share = second_share
|
self.second_share = second_share
|
||||||
|
self.session = session
|
||||||
|
|
||||||
def input_flow_common(self) -> BRGeneratorType:
|
def input_flow_common(self) -> BRGeneratorType:
|
||||||
yield from self.REC.confirm_recovery()
|
yield from self.REC.confirm_recovery()
|
||||||
@ -2018,7 +2027,7 @@ class InputFlowSlip39AdvancedRecoveryShareAlreadyEntered(InputFlowBase):
|
|||||||
yield from self.REC.warning_share_already_entered()
|
yield from self.REC.warning_share_already_entered()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
self.client.cancel()
|
self.session.cancel()
|
||||||
|
|
||||||
|
|
||||||
class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase):
|
class InputFlowSlip39BasicRecoveryDryRun(InputFlowBase):
|
||||||
@ -2117,10 +2126,12 @@ class InputFlowSlip39BasicRecoveryNoAbort(InputFlowBase):
|
|||||||
|
|
||||||
|
|
||||||
class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
|
class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
|
||||||
def __init__(self, client: Client):
|
|
||||||
super().__init__(client)
|
def __init__(self, session: Session):
|
||||||
|
super().__init__(session.client)
|
||||||
self.first_invalid = ["slush"] * 20
|
self.first_invalid = ["slush"] * 20
|
||||||
self.second_invalid = ["slush"] * 33
|
self.second_invalid = ["slush"] * 33
|
||||||
|
self.session = session
|
||||||
|
|
||||||
def input_flow_common(self) -> BRGeneratorType:
|
def input_flow_common(self) -> BRGeneratorType:
|
||||||
yield from self.REC.confirm_recovery()
|
yield from self.REC.confirm_recovery()
|
||||||
@ -2132,16 +2143,18 @@ class InputFlowSlip39BasicRecoveryInvalidFirstShare(InputFlowBase):
|
|||||||
yield from self.REC.warning_invalid_recovery_share()
|
yield from self.REC.warning_invalid_recovery_share()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
self.client.cancel()
|
self.session.cancel()
|
||||||
|
|
||||||
|
|
||||||
class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
|
class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
|
||||||
def __init__(self, client: Client, shares: list[str]):
|
|
||||||
super().__init__(client)
|
def __init__(self, session: Session, shares: list[str]):
|
||||||
|
super().__init__(session.client)
|
||||||
self.shares = shares
|
self.shares = shares
|
||||||
self.first_share = shares[0].split(" ")
|
self.first_share = shares[0].split(" ")
|
||||||
self.invalid_share = self.first_share[:3] + ["slush"] * 17
|
self.invalid_share = self.first_share[:3] + ["slush"] * 17
|
||||||
self.second_share = shares[1].split(" ")
|
self.second_share = shares[1].split(" ")
|
||||||
|
self.session = session
|
||||||
|
|
||||||
def input_flow_common(self) -> BRGeneratorType:
|
def input_flow_common(self) -> BRGeneratorType:
|
||||||
yield from self.REC.confirm_recovery()
|
yield from self.REC.confirm_recovery()
|
||||||
@ -2154,16 +2167,18 @@ class InputFlowSlip39BasicRecoveryInvalidSecondShare(InputFlowBase):
|
|||||||
yield from self.REC.success_more_shares_needed(1)
|
yield from self.REC.success_more_shares_needed(1)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
self.client.cancel()
|
self.session.cancel()
|
||||||
|
|
||||||
|
|
||||||
class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
|
class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
|
||||||
def __init__(self, client: Client, share: list[str], nth_word: int):
|
|
||||||
super().__init__(client)
|
def __init__(self, session: Session, share: list[str], nth_word: int):
|
||||||
|
super().__init__(session.client)
|
||||||
self.share = share
|
self.share = share
|
||||||
self.nth_word = nth_word
|
self.nth_word = nth_word
|
||||||
# Invalid share - just enough words to trigger the warning
|
# Invalid share - just enough words to trigger the warning
|
||||||
self.modified_share = share[:nth_word] + [self.share[-1]]
|
self.modified_share = share[:nth_word] + [self.share[-1]]
|
||||||
|
self.session = session
|
||||||
|
|
||||||
def input_flow_common(self) -> BRGeneratorType:
|
def input_flow_common(self) -> BRGeneratorType:
|
||||||
yield from self.REC.confirm_recovery()
|
yield from self.REC.confirm_recovery()
|
||||||
@ -2174,15 +2189,17 @@ class InputFlowSlip39BasicRecoveryWrongNthWord(InputFlowBase):
|
|||||||
yield from self.REC.warning_share_from_another_shamir()
|
yield from self.REC.warning_share_from_another_shamir()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
self.client.cancel()
|
self.session.cancel()
|
||||||
|
|
||||||
|
|
||||||
class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
|
class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
|
||||||
def __init__(self, client: Client, share: list[str]):
|
|
||||||
super().__init__(client)
|
def __init__(self, session: Session, share: list[str]):
|
||||||
|
super().__init__(session.client)
|
||||||
self.share = share
|
self.share = share
|
||||||
# Second duplicate share - only 4 words are needed to verify it
|
# Second duplicate share - only 4 words are needed to verify it
|
||||||
self.duplicate_share = self.share[:4]
|
self.duplicate_share = self.share[:4]
|
||||||
|
self.session = session
|
||||||
|
|
||||||
def input_flow_common(self) -> BRGeneratorType:
|
def input_flow_common(self) -> BRGeneratorType:
|
||||||
yield from self.REC.confirm_recovery()
|
yield from self.REC.confirm_recovery()
|
||||||
@ -2193,7 +2210,7 @@ class InputFlowSlip39BasicRecoverySameShare(InputFlowBase):
|
|||||||
yield from self.REC.warning_share_already_entered()
|
yield from self.REC.warning_share_already_entered()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
self.client.cancel()
|
self.session.cancel()
|
||||||
|
|
||||||
|
|
||||||
class InputFlowResetSkipBackup(InputFlowBase):
|
class InputFlowResetSkipBackup(InputFlowBase):
|
||||||
|
@ -8,7 +8,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from trezorlib import cosi, device, models
|
from trezorlib import cosi, device, models
|
||||||
from trezorlib._internal import translations
|
from trezorlib._internal import translations
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||||
|
|
||||||
from . import common
|
from . import common
|
||||||
|
|
||||||
@ -58,20 +58,19 @@ def sign_blob(blob: translations.TranslationsBlob) -> bytes:
|
|||||||
|
|
||||||
def build_and_sign_blob(
|
def build_and_sign_blob(
|
||||||
lang_or_def: translations.JsonDef | Path | str,
|
lang_or_def: translations.JsonDef | Path | str,
|
||||||
client: Client,
|
session: Session,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
blob = prepare_blob(lang_or_def, client.model, client.version)
|
blob = prepare_blob(lang_or_def, session.model, session.version)
|
||||||
return sign_blob(blob)
|
return sign_blob(blob)
|
||||||
|
|
||||||
|
|
||||||
def set_language(client: Client, lang: str):
|
def set_language(session: Session, lang: str):
|
||||||
if lang.startswith("en"):
|
if lang.startswith("en"):
|
||||||
language_data = b""
|
language_data = b""
|
||||||
else:
|
else:
|
||||||
language_data = build_and_sign_blob(lang, client)
|
language_data = build_and_sign_blob(lang, session)
|
||||||
with client:
|
with session:
|
||||||
device.change_language(client, language_data) # type: ignore
|
device.change_language(session, language_data) # type: ignore
|
||||||
_CURRENT_TRANSLATION.TR = TRANSLATIONS[lang]
|
|
||||||
|
|
||||||
|
|
||||||
def get_lang_json(lang: str) -> translations.JsonDef:
|
def get_lang_json(lang: str) -> translations.JsonDef:
|
||||||
|
Loading…
Reference in New Issue
Block a user