diff --git a/tests/background.py b/tests/background.py new file mode 100644 index 0000000000..a8a5f3cec8 --- /dev/null +++ b/tests/background.py @@ -0,0 +1,53 @@ +from concurrent.futures import ThreadPoolExecutor + + +class NullUI: + @staticmethod + def button_request(code): + pass + + @staticmethod + def get_pin(code=None): + raise NotImplementedError("Should not be used with T1") + + @staticmethod + def get_passphrase(): + raise NotImplementedError("Should not be used with T1") + + +class BackgroundDeviceHandler: + _pool = ThreadPoolExecutor() + + def __init__(self, client): + self.client = client + self.client.ui = NullUI + self.task = None + + def run(self, function, *args, **kwargs): + if self.task is not None: + raise RuntimeError("Wait for previous task first") + self.task = self._pool.submit(function, self.client, *args, **kwargs) + + def result(self): + if self.task is None: + raise RuntimeError("No task running") + try: + return self.task.result() + finally: + self.task = None + + def features(self): + if self.task is not None: + raise RuntimeError("Cannot query features while task is running") + self.client.init_device() + return self.client.features + + def debuglink(self): + return self.client.debug + + def check_finalize(self): + if self.task is not None: + self.task.cancel() + self.task = None + return False + return True diff --git a/tests/click_tests/test_recovery.py b/tests/click_tests/test_recovery.py index 4e8b1b909c..ec596f0207 100644 --- a/tests/click_tests/test_recovery.py +++ b/tests/click_tests/test_recovery.py @@ -19,49 +19,46 @@ def click_ok(debug): @pytest.mark.skip_t1 @pytest.mark.setup_client(uninitialized=True) -def test_recovery(client): - with client: - client.set_expected_responses( - [ - messages.ButtonRequest(code=messages.ButtonRequestType.ProtectCall), - messages.Success(), - messages.Features(), - ] - ) - device.recover(client, pin_protection=False) +def test_recovery(device_handler): + features = device_handler.features() + debug = device_handler.debuglink() - assert client.features.initialized is False - assert client.features.recovery_mode is True + assert features.initialized is False + device_handler.run(device.recover, pin_protection=False) # select number of words - state = client.debug.state() - text = " ".join(state.layout_lines) + text = " ".join(debug.wait_layout()) + assert text.startswith("Recovery mode") + text = click_ok(debug) + assert "Select number of words" in text - text = click_ok(client.debug) + text = click_ok(debug) assert text == "WordSelector" # click "20" at 2, 2 coords = buttons.grid34(2, 2) - lines = client.debug.click(coords, wait=True) + lines = debug.click(coords, wait=True) text = " ".join(lines) expected_text = "Enter any share (20 words)" remaining = len(MNEMONIC_SLIP39_BASIC_20_3of6) for share in MNEMONIC_SLIP39_BASIC_20_3of6: assert expected_text in text - text = click_ok(client.debug) + text = click_ok(debug) assert text == "Slip39Keyboard" for word in share.split(" "): - text = enter_word(client.debug, word) + text = enter_word(debug, word) remaining -= 1 expected_text = "RecoveryHomescreen {} more".format(remaining) assert "You have successfully recovered your wallet" in text - text = click_ok(client.debug) + text = click_ok(debug) assert text == "Homescreen" - client.init_device() - assert client.features.initialized is True - assert client.features.recovery_mode is False + + assert isinstance(device_handler.result(), messages.Success) + features = device_handler.features() + assert features.initialized is True + assert features.recovery_mode is False diff --git a/tests/conftest.py b/tests/conftest.py index faa1ad2508..464b35722c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,8 @@ from trezorlib.device import apply_settings, wipe as wipe_device from trezorlib.messages.PassphraseSourceType import HOST as PASSPHRASE_ON_HOST from trezorlib.transport import enumerate_devices, get_transport +from .background import BackgroundDeviceHandler + def get_device(): path = os.environ.get("TREZOR_PATH") @@ -156,3 +158,23 @@ def pytest_runtest_setup(item): skip_altcoins = int(os.environ.get("TREZOR_PYTEST_SKIP_ALTCOINS", 0)) if item.get_closest_marker("altcoin") and skip_altcoins: pytest.skip("Skipping altcoin test") + + +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + # Make test results available in fixtures. + # See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures + outcome = yield + rep = outcome.get_result() + setattr(item, f"rep_{rep.when}", rep) + + +@pytest.fixture +def device_handler(client, request): + device_handler = BackgroundDeviceHandler(client) + yield device_handler + + # make sure all background tasks are done + finalized_ok = device_handler.check_finalize() + if request.node.rep_call.passed and not finalized_ok: + raise RuntimeError("Test did not check result of background task")