From c37bc9c38eb405ed3757e53cd5ecec93caf45af9 Mon Sep 17 00:00:00 2001 From: matejcik Date: Tue, 2 Oct 2018 17:18:13 +0200 Subject: [PATCH] debug: improve infrastructure and expected message reporting --- trezorlib/client.py | 18 +++- trezorlib/debuglink.py | 112 ++++++++++++++++------- trezorlib/tests/device_tests/common.py | 2 +- trezorlib/tests/device_tests/conftest.py | 12 +-- trezorlib/tools.py | 1 + 5 files changed, 100 insertions(+), 45 deletions(-) diff --git a/trezorlib/client.py b/trezorlib/client.py index f44b5c5702..3f17b9bd16 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -27,7 +27,6 @@ from mnemonic import Mnemonic from . import ( btc, cosi, - debuglink, device, ethereum, exceptions, @@ -96,12 +95,20 @@ class BaseClient(object): pass def cancel(self): - self.transport.write(proto.Cancel()) + self._raw_write(proto.Cancel()) @tools.session def call_raw(self, msg): + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + self._raw_write(msg) + return self._raw_read() + + def _raw_write(self, msg): __tracebackhide__ = True # for pytest # pylint: disable=W0612 self.transport.write(msg) + + def _raw_read(self): + __tracebackhide__ = True # for pytest # pylint: disable=W0612 return self.transport.read() def callback_PinMatrixRequest(self, msg): @@ -115,7 +122,7 @@ class BaseClient(object): proto.FailureType.PinCancelled, proto.FailureType.PinExpected, ): - raise exceptions.PinException(msg.code, msg.message) + raise exceptions.PinException(resp.code, resp.message) else: return resp @@ -131,10 +138,11 @@ class BaseClient(object): return self.call_raw(proto.PassphraseStateAck()) def callback_ButtonRequest(self, msg): + __tracebackhide__ = True # for pytest # pylint: disable=W0612 # do this raw - send ButtonAck first, notify UI later - self.transport.write(proto.ButtonAck()) + self._raw_write(proto.ButtonAck()) self.ui.button_request(msg.code) - return self.transport.read() + return self._raw_read() @tools.session def call(self, msg): diff --git a/trezorlib/debuglink.py b/trezorlib/debuglink.py index dc11703956..0549f78664 100644 --- a/trezorlib/debuglink.py +++ b/trezorlib/debuglink.py @@ -20,6 +20,7 @@ from mnemonic import Mnemonic from . import messages as proto, tools from .client import TrezorClient from .tools import expect +from .protobuf import format_message class DebugLink: @@ -126,15 +127,26 @@ class DebugLink: class DebugUI: + INPUT_FLOW_DONE = object() + def __init__(self, debuglink: DebugLink): self.debuglink = debuglink self.pin = None self.passphrase = "sphinx of black quartz, judge my wov" + self.input_flow = None - def button_request(self): - self.debuglink.press_yes() + def button_request(self, code): + if self.input_flow is None: + self.debuglink.press_yes() + elif self.input_flow is self.INPUT_FLOW_DONE: + raise AssertionError("input flow ended prematurely") + else: + try: + self.input_flow.send(code) + except StopIteration: + self.input_flow = self.INPUT_FLOW_DONE - def get_pin(self): + def get_pin(self, code=None): if self.pin: return self.pin else: @@ -154,12 +166,10 @@ class TrezorClientDebugLink(TrezorClient): # of unit testing, because it will fail to work # without special DebugLink interface provided # by the device. - DEBUG = LOG.getChild("debug_link").debug def __init__(self, transport): self.debug = DebugLink(transport.find_debug()) self.ui = DebugUI(self.debug) - super().__init__(transport, self.ui) self.in_with_statement = 0 self.button_wait = 0 @@ -170,9 +180,11 @@ class TrezorClientDebugLink(TrezorClient): # Do not expect any specific response from device self.expected_responses = None + self.current_response = None # Use blank passphrase self.set_passphrase("") + super().__init__(transport, ui=self.ui) def close(self): super().close() @@ -182,6 +194,14 @@ class TrezorClientDebugLink(TrezorClient): def set_buttonwait(self, secs): self.button_wait = secs + def set_input_flow(self, input_flow): + if callable(input_flow): + input_flow = input_flow() + if not hasattr(input_flow, "send"): + raise RuntimeError("input_flow should be a generator function") + self.ui.input_flow = input_flow + next(input_flow) # can't send before first yield + def __enter__(self): # For usage in with/expected_responses self.in_with_statement += 1 @@ -196,20 +216,19 @@ class TrezorClientDebugLink(TrezorClient): # return isinstance(value, TypeError) # Evaluate missed responses in 'with' statement - if self.expected_responses is not None and len(self.expected_responses): - raise RuntimeError( - "Some of expected responses didn't come from device: %s" - % [repr(x) for x in self.expected_responses] - ) + if self.current_response < len(self.expected_responses): + self._raise_unexpected_response(None) # Cleanup self.expected_responses = None + self.current_response = None return False def set_expected_responses(self, expected): if not self.in_with_statement: raise RuntimeError("Must be called inside 'with' statement") self.expected_responses = expected + self.current_response = 0 def setup_debuglink(self, button, pin_correct): # self.button = button # True -> YES button, False -> NO button @@ -224,7 +243,7 @@ class TrezorClientDebugLink(TrezorClient): def set_mnemonic(self, mnemonic): self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - def call_raw(self, msg): + def _raw_read(self): __tracebackhide__ = True # for pytest # pylint: disable=W0612 # if SCREENSHOT and self.debug: @@ -241,36 +260,63 @@ class TrezorClientDebugLink(TrezorClient): # im.save("scr%05d.png" % self.screenshot_id) # self.screenshot_id += 1 - resp = super().call_raw(msg) + resp = super()._raw_read() self._check_request(resp) return resp - def _check_request(self, msg): + def _raise_unexpected_response(self, msg): __tracebackhide__ = True # for pytest # pylint: disable=W0612 - if self.expected_responses is not None: - try: - expected = self.expected_responses.pop(0) - except IndexError: - raise AssertionError( - proto.FailureType.UnexpectedMessage, - "Got %s, but no message has been expected" % repr(msg), + output = [] + output.append("Expected responses:") + for i, exp in enumerate(self.expected_responses): + prefix = " " if i != self.current_response else ">>> " + set_fields = { + key: value + for key, value in exp.__dict__.items() + if value is not None and value != [] + } + oneline_str = ", ".join("{}={!r}".format(*i) for i in set_fields.items()) + if len(oneline_str) < 60: + output.append( + "{}{}({})".format(prefix, exp.__class__.__name__, oneline_str) ) + else: + output.append("{}{}(".format(prefix, exp.__class__.__name__)) + for key, value in set_fields.items(): + output.append("{} {}={!r}".format(prefix, key, value)) + output.append("{})".format(prefix)) - if msg.__class__ != expected.__class__: - raise AssertionError( - proto.FailureType.UnexpectedMessage, - "Expected %s, got %s" % (repr(expected), repr(msg)), - ) + output.append("") + if msg is not None: + output.append("Actually received:") + output.append(format_message(msg)) + else: + output.append("This message was never received.") + raise AssertionError("\n".join(output)) - for field, value in expected.__dict__.items(): - if value is None or value == []: - continue - if getattr(msg, field) != value: - raise AssertionError( - proto.FailureType.UnexpectedMessage, - "Expected %s, got %s" % (repr(expected), repr(msg)), - ) + def _check_request(self, msg): + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + if self.expected_responses is None: + return + + if self.current_response >= len(self.expected_responses): + raise AssertionError( + "No more messages were expected, but we got:\n" + format_message(msg) + ) + + expected = self.expected_responses[self.current_response] + + if msg.__class__ != expected.__class__: + self._raise_unexpected_response(msg) + + for field, value in expected.__dict__.items(): + if value is None or value == []: + continue + if getattr(msg, field) != value: + self._raise_unexpected_response(msg) + + self.current_response += 1 def mnemonic_callback(self, _): word, pos = self.debug.read_recovery_word() diff --git a/trezorlib/tests/device_tests/common.py b/trezorlib/tests/device_tests/common.py index e905c11ff7..2877b63ad0 100644 --- a/trezorlib/tests/device_tests/common.py +++ b/trezorlib/tests/device_tests/common.py @@ -63,7 +63,7 @@ class TrezorTest: label="test", language="english", ) - if passphrase: + if conftest.TREZOR_VERSION > 1 and passphrase: device.apply_settings(self.client, passphrase_source=PASSPHRASE_ON_HOST) def setup_mnemonic_allallall(self): diff --git a/trezorlib/tests/device_tests/conftest.py b/trezorlib/tests/device_tests/conftest.py index 25d84c59d8..cb750e2833 100644 --- a/trezorlib/tests/device_tests/conftest.py +++ b/trezorlib/tests/device_tests/conftest.py @@ -20,8 +20,9 @@ import os import pytest from trezorlib import coins, log -from trezorlib.client import TrezorClient, TrezorClientDebugLink +from trezorlib.debuglink import TrezorClientDebugLink from trezorlib.transport import enumerate_devices, get_transport +from trezorlib import device, debuglink TREZOR_VERSION = None @@ -42,7 +43,7 @@ def device_version(): device = get_device() if not device: raise RuntimeError() - client = TrezorClient(device) + client = TrezorClientDebugLink(device) if client.features.model == "T": return 2 else: @@ -52,11 +53,9 @@ def device_version(): @pytest.fixture(scope="function") def client(): wirelink = get_device() - debuglink = wirelink.find_debug() client = TrezorClientDebugLink(wirelink) - client.set_debuglink(debuglink) client.set_tx_api(coins.tx_api["Bitcoin"]) - client.wipe_device() + device.wipe(client) client.transport.session_begin() yield client @@ -78,7 +77,8 @@ def setup_client(mnemonic=None, pin="", passphrase=False): def client_decorator(function): @functools.wraps(function) def wrapper(client, *args, **kwargs): - client.load_device_by_mnemonic( + debuglink.load_device_by_mnemonic( + client, mnemonic=mnemonic, pin=pin, passphrase_protection=passphrase, diff --git a/trezorlib/tools.py b/trezorlib/tools.py index 264d1092a2..72f8a01ece 100644 --- a/trezorlib/tools.py +++ b/trezorlib/tools.py @@ -188,6 +188,7 @@ class expect: def __call__(self, f): @functools.wraps(f) def wrapped_f(*args, **kwargs): + __tracebackhide__ = True # for pytest # pylint: disable=W0612 ret = f(*args, **kwargs) if not isinstance(ret, self.expected): raise RuntimeError(