diff --git a/trezorlib/debuglink.py b/trezorlib/debuglink.py index 2987f6ea5f..de2609abff 100644 --- a/trezorlib/debuglink.py +++ b/trezorlib/debuglink.py @@ -26,8 +26,9 @@ EXPECTED_RESPONSES_CONTEXT_LINES = 3 class DebugLink: - def __init__(self, transport): + def __init__(self, transport, auto_interact=True): self.transport = transport + self.allow_interactions = auto_interact def open(self): self.transport.begin_session() @@ -87,6 +88,8 @@ class DebugLink: return obj.passphrase_protection def input(self, word=None, button=None, swipe=None): + if not self.allow_interactions: + return decision = proto.DebugLinkDecision() if button is not None: decision.yes_no = button @@ -130,6 +133,24 @@ class DebugLink: self._call(proto.DebugLinkFlashErase(sector=sector), nowait=True) +class NullDebugLink(DebugLink): + def __init__(self): + super().__init__(None) + + def open(self): + pass + + def close(self): + pass + + def _call(self, msg, nowait=False): + if not nowait: + if isinstance(msg, proto.DebugLinkGetState): + return proto.DebugLinkState() + else: + raise RuntimeError("unexpected call to a fake debuglink") + + class DebugUI: INPUT_FLOW_DONE = object() @@ -171,8 +192,16 @@ class TrezorClientDebugLink(TrezorClient): # without special DebugLink interface provided # by the device. - def __init__(self, transport): - self.debug = DebugLink(transport.find_debug()) + def __init__(self, transport, auto_interact=True): + try: + debug_transport = transport.find_debug() + self.debug = DebugLink(debug_transport, auto_interact) + except Exception: + if not auto_interact: + self.debug = NullDebugLink() + else: + raise + self.ui = DebugUI(self.debug) self.in_with_statement = 0 diff --git a/trezorlib/tests/device_tests/common.py b/trezorlib/tests/device_tests/common.py index 18dbb7d3e7..5c8a0a7131 100644 --- a/trezorlib/tests/device_tests/common.py +++ b/trezorlib/tests/device_tests/common.py @@ -15,7 +15,6 @@ # If not, see . from trezorlib import debuglink, device -from trezorlib.debuglink import TrezorClientDebugLink from trezorlib.messages.PassphraseSourceType import HOST as PASSPHRASE_ON_HOST from . import conftest @@ -35,8 +34,7 @@ class TrezorTest: pin8 = "45678978" def setup_method(self, method): - wirelink = conftest.get_device() - self.client = TrezorClientDebugLink(wirelink) + self.client = conftest.get_device() # self.client.set_buttonwait(3) device.wipe(self.client) diff --git a/trezorlib/tests/device_tests/conftest.py b/trezorlib/tests/device_tests/conftest.py index bfc5d5cf2d..7e82c7fe0b 100644 --- a/trezorlib/tests/device_tests/conftest.py +++ b/trezorlib/tests/device_tests/conftest.py @@ -30,20 +30,26 @@ TREZOR_VERSION = None def get_device(): path = os.environ.get("TREZOR_PATH") if path: - return get_transport(path) + transport = get_transport(path) else: devices = enumerate_devices() for device in devices: if hasattr(device, "find_debug"): - return device - raise RuntimeError("No debuggable device found") + transport = device + break + else: + raise RuntimeError("No debuggable device found") + env_interactive = int(os.environ.get("INTERACT", 0)) + try: + return TrezorClientDebugLink(transport, auto_interact=not env_interactive) + except Exception as e: + raise RuntimeError( + "Failed to open debuglink for {}".format(transport.get_path()) + ) from e def device_version(): - device = get_device() - if not device: - raise RuntimeError() - client = TrezorClientDebugLink(device) + client = get_device() if client.features.model == "T": return 2 else: @@ -52,8 +58,7 @@ def device_version(): @pytest.fixture(scope="function") def client(): - wirelink = get_device() - client = TrezorClientDebugLink(wirelink) + client = get_device() wipe_device(client) client.open() @@ -100,6 +105,11 @@ def pytest_addoption(parser): "args", [], ) + parser.addoption( + "--interactive", + action="store_true", + help="Wait for user to do interaction manually", + ) def pytest_runtest_setup(item):