1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

debug: improve infrastructure and expected message reporting

This commit is contained in:
matejcik 2018-10-02 17:18:13 +02:00
parent fc7a76e2f3
commit c37bc9c38e
5 changed files with 100 additions and 45 deletions

View File

@ -27,7 +27,6 @@ from mnemonic import Mnemonic
from . import ( from . import (
btc, btc,
cosi, cosi,
debuglink,
device, device,
ethereum, ethereum,
exceptions, exceptions,
@ -96,12 +95,20 @@ class BaseClient(object):
pass pass
def cancel(self): def cancel(self):
self.transport.write(proto.Cancel()) self._raw_write(proto.Cancel())
@tools.session @tools.session
def call_raw(self, msg): def call_raw(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg)
return self._raw_read()
def _raw_write(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
self.transport.write(msg) self.transport.write(msg)
def _raw_read(self):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
return self.transport.read() return self.transport.read()
def callback_PinMatrixRequest(self, msg): def callback_PinMatrixRequest(self, msg):
@ -115,7 +122,7 @@ class BaseClient(object):
proto.FailureType.PinCancelled, proto.FailureType.PinCancelled,
proto.FailureType.PinExpected, proto.FailureType.PinExpected,
): ):
raise exceptions.PinException(msg.code, msg.message) raise exceptions.PinException(resp.code, resp.message)
else: else:
return resp return resp
@ -131,10 +138,11 @@ class BaseClient(object):
return self.call_raw(proto.PassphraseStateAck()) return self.call_raw(proto.PassphraseStateAck())
def callback_ButtonRequest(self, msg): def callback_ButtonRequest(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later # do this raw - send ButtonAck first, notify UI later
self.transport.write(proto.ButtonAck()) self._raw_write(proto.ButtonAck())
self.ui.button_request(msg.code) self.ui.button_request(msg.code)
return self.transport.read() return self._raw_read()
@tools.session @tools.session
def call(self, msg): def call(self, msg):

View File

@ -20,6 +20,7 @@ from mnemonic import Mnemonic
from . import messages as proto, tools from . import messages as proto, tools
from .client import TrezorClient from .client import TrezorClient
from .tools import expect from .tools import expect
from .protobuf import format_message
class DebugLink: class DebugLink:
@ -126,15 +127,26 @@ class DebugLink:
class DebugUI: class DebugUI:
INPUT_FLOW_DONE = object()
def __init__(self, debuglink: DebugLink): def __init__(self, debuglink: DebugLink):
self.debuglink = debuglink self.debuglink = debuglink
self.pin = None self.pin = None
self.passphrase = "sphinx of black quartz, judge my wov" self.passphrase = "sphinx of black quartz, judge my wov"
self.input_flow = None
def button_request(self): def button_request(self, code):
self.debuglink.press_yes() if self.input_flow is None:
self.debuglink.press_yes()
elif self.input_flow is self.INPUT_FLOW_DONE:
raise AssertionError("input flow ended prematurely")
else:
try:
self.input_flow.send(code)
except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self): def get_pin(self, code=None):
if self.pin: if self.pin:
return self.pin return self.pin
else: else:
@ -154,12 +166,10 @@ class TrezorClientDebugLink(TrezorClient):
# of unit testing, because it will fail to work # of unit testing, because it will fail to work
# without special DebugLink interface provided # without special DebugLink interface provided
# by the device. # by the device.
DEBUG = LOG.getChild("debug_link").debug
def __init__(self, transport): def __init__(self, transport):
self.debug = DebugLink(transport.find_debug()) self.debug = DebugLink(transport.find_debug())
self.ui = DebugUI(self.debug) self.ui = DebugUI(self.debug)
super().__init__(transport, self.ui)
self.in_with_statement = 0 self.in_with_statement = 0
self.button_wait = 0 self.button_wait = 0
@ -170,9 +180,11 @@ class TrezorClientDebugLink(TrezorClient):
# Do not expect any specific response from device # Do not expect any specific response from device
self.expected_responses = None self.expected_responses = None
self.current_response = None
# Use blank passphrase # Use blank passphrase
self.set_passphrase("") self.set_passphrase("")
super().__init__(transport, ui=self.ui)
def close(self): def close(self):
super().close() super().close()
@ -182,6 +194,14 @@ class TrezorClientDebugLink(TrezorClient):
def set_buttonwait(self, secs): def set_buttonwait(self, secs):
self.button_wait = secs self.button_wait = secs
def set_input_flow(self, input_flow):
if callable(input_flow):
input_flow = input_flow()
if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function")
self.ui.input_flow = input_flow
next(input_flow) # can't send before first yield
def __enter__(self): def __enter__(self):
# For usage in with/expected_responses # For usage in with/expected_responses
self.in_with_statement += 1 self.in_with_statement += 1
@ -196,20 +216,19 @@ class TrezorClientDebugLink(TrezorClient):
# return isinstance(value, TypeError) # return isinstance(value, TypeError)
# Evaluate missed responses in 'with' statement # Evaluate missed responses in 'with' statement
if self.expected_responses is not None and len(self.expected_responses): if self.current_response < len(self.expected_responses):
raise RuntimeError( self._raise_unexpected_response(None)
"Some of expected responses didn't come from device: %s"
% [repr(x) for x in self.expected_responses]
)
# Cleanup # Cleanup
self.expected_responses = None self.expected_responses = None
self.current_response = None
return False return False
def set_expected_responses(self, expected): def set_expected_responses(self, expected):
if not self.in_with_statement: if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement") raise RuntimeError("Must be called inside 'with' statement")
self.expected_responses = expected self.expected_responses = expected
self.current_response = 0
def setup_debuglink(self, button, pin_correct): def setup_debuglink(self, button, pin_correct):
# self.button = button # True -> YES button, False -> NO button # self.button = button # True -> YES button, False -> NO button
@ -224,7 +243,7 @@ class TrezorClientDebugLink(TrezorClient):
def set_mnemonic(self, mnemonic): def set_mnemonic(self, mnemonic):
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def call_raw(self, msg): def _raw_read(self):
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
# if SCREENSHOT and self.debug: # if SCREENSHOT and self.debug:
@ -241,36 +260,63 @@ class TrezorClientDebugLink(TrezorClient):
# im.save("scr%05d.png" % self.screenshot_id) # im.save("scr%05d.png" % self.screenshot_id)
# self.screenshot_id += 1 # self.screenshot_id += 1
resp = super().call_raw(msg) resp = super()._raw_read()
self._check_request(resp) self._check_request(resp)
return resp return resp
def _check_request(self, msg): def _raise_unexpected_response(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
if self.expected_responses is not None: output = []
try: output.append("Expected responses:")
expected = self.expected_responses.pop(0) for i, exp in enumerate(self.expected_responses):
except IndexError: prefix = " " if i != self.current_response else ">>> "
raise AssertionError( set_fields = {
proto.FailureType.UnexpectedMessage, key: value
"Got %s, but no message has been expected" % repr(msg), for key, value in exp.__dict__.items()
if value is not None and value != []
}
oneline_str = ", ".join("{}={!r}".format(*i) for i in set_fields.items())
if len(oneline_str) < 60:
output.append(
"{}{}({})".format(prefix, exp.__class__.__name__, oneline_str)
) )
else:
output.append("{}{}(".format(prefix, exp.__class__.__name__))
for key, value in set_fields.items():
output.append("{} {}={!r}".format(prefix, key, value))
output.append("{})".format(prefix))
if msg.__class__ != expected.__class__: output.append("")
raise AssertionError( if msg is not None:
proto.FailureType.UnexpectedMessage, output.append("Actually received:")
"Expected %s, got %s" % (repr(expected), repr(msg)), output.append(format_message(msg))
) else:
output.append("This message was never received.")
raise AssertionError("\n".join(output))
for field, value in expected.__dict__.items(): def _check_request(self, msg):
if value is None or value == []: __tracebackhide__ = True # for pytest # pylint: disable=W0612
continue if self.expected_responses is None:
if getattr(msg, field) != value: return
raise AssertionError(
proto.FailureType.UnexpectedMessage, if self.current_response >= len(self.expected_responses):
"Expected %s, got %s" % (repr(expected), repr(msg)), raise AssertionError(
) "No more messages were expected, but we got:\n" + format_message(msg)
)
expected = self.expected_responses[self.current_response]
if msg.__class__ != expected.__class__:
self._raise_unexpected_response(msg)
for field, value in expected.__dict__.items():
if value is None or value == []:
continue
if getattr(msg, field) != value:
self._raise_unexpected_response(msg)
self.current_response += 1
def mnemonic_callback(self, _): def mnemonic_callback(self, _):
word, pos = self.debug.read_recovery_word() word, pos = self.debug.read_recovery_word()

View File

@ -63,7 +63,7 @@ class TrezorTest:
label="test", label="test",
language="english", language="english",
) )
if passphrase: if conftest.TREZOR_VERSION > 1 and passphrase:
device.apply_settings(self.client, passphrase_source=PASSPHRASE_ON_HOST) device.apply_settings(self.client, passphrase_source=PASSPHRASE_ON_HOST)
def setup_mnemonic_allallall(self): def setup_mnemonic_allallall(self):

View File

@ -20,8 +20,9 @@ import os
import pytest import pytest
from trezorlib import coins, log from trezorlib import coins, log
from trezorlib.client import TrezorClient, TrezorClientDebugLink from trezorlib.debuglink import TrezorClientDebugLink
from trezorlib.transport import enumerate_devices, get_transport from trezorlib.transport import enumerate_devices, get_transport
from trezorlib import device, debuglink
TREZOR_VERSION = None TREZOR_VERSION = None
@ -42,7 +43,7 @@ def device_version():
device = get_device() device = get_device()
if not device: if not device:
raise RuntimeError() raise RuntimeError()
client = TrezorClient(device) client = TrezorClientDebugLink(device)
if client.features.model == "T": if client.features.model == "T":
return 2 return 2
else: else:
@ -52,11 +53,9 @@ def device_version():
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def client(): def client():
wirelink = get_device() wirelink = get_device()
debuglink = wirelink.find_debug()
client = TrezorClientDebugLink(wirelink) client = TrezorClientDebugLink(wirelink)
client.set_debuglink(debuglink)
client.set_tx_api(coins.tx_api["Bitcoin"]) client.set_tx_api(coins.tx_api["Bitcoin"])
client.wipe_device() device.wipe(client)
client.transport.session_begin() client.transport.session_begin()
yield client yield client
@ -78,7 +77,8 @@ def setup_client(mnemonic=None, pin="", passphrase=False):
def client_decorator(function): def client_decorator(function):
@functools.wraps(function) @functools.wraps(function)
def wrapper(client, *args, **kwargs): def wrapper(client, *args, **kwargs):
client.load_device_by_mnemonic( debuglink.load_device_by_mnemonic(
client,
mnemonic=mnemonic, mnemonic=mnemonic,
pin=pin, pin=pin,
passphrase_protection=passphrase, passphrase_protection=passphrase,

View File

@ -188,6 +188,7 @@ class expect:
def __call__(self, f): def __call__(self, f):
@functools.wraps(f) @functools.wraps(f)
def wrapped_f(*args, **kwargs): def wrapped_f(*args, **kwargs):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
ret = f(*args, **kwargs) ret = f(*args, **kwargs)
if not isinstance(ret, self.expected): if not isinstance(ret, self.expected):
raise RuntimeError( raise RuntimeError(