debug: improve infrastructure and expected message reporting

pull/25/head
matejcik 6 years ago
parent fc7a76e2f3
commit c37bc9c38e

@ -27,7 +27,6 @@ from mnemonic import Mnemonic
from . import (
btc,
cosi,
debuglink,
device,
ethereum,
exceptions,
@ -96,12 +95,20 @@ class BaseClient(object):
pass
def cancel(self):
self.transport.write(proto.Cancel())
self._raw_write(proto.Cancel())
@tools.session
def call_raw(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self._raw_write(msg)
return self._raw_read()
def _raw_write(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self.transport.write(msg)
def _raw_read(self):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
return self.transport.read()
def callback_PinMatrixRequest(self, msg):
@ -115,7 +122,7 @@ class BaseClient(object):
proto.FailureType.PinCancelled,
proto.FailureType.PinExpected,
):
raise exceptions.PinException(msg.code, msg.message)
raise exceptions.PinException(resp.code, resp.message)
else:
return resp
@ -131,10 +138,11 @@ class BaseClient(object):
return self.call_raw(proto.PassphraseStateAck())
def callback_ButtonRequest(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
self.transport.write(proto.ButtonAck())
self._raw_write(proto.ButtonAck())
self.ui.button_request(msg.code)
return self.transport.read()
return self._raw_read()
@tools.session
def call(self, msg):

@ -20,6 +20,7 @@ from mnemonic import Mnemonic
from . import messages as proto, tools
from .client import TrezorClient
from .tools import expect
from .protobuf import format_message
class DebugLink:
@ -126,15 +127,26 @@ class DebugLink:
class DebugUI:
INPUT_FLOW_DONE = object()
def __init__(self, debuglink: DebugLink):
self.debuglink = debuglink
self.pin = None
self.passphrase = "sphinx of black quartz, judge my wov"
self.input_flow = None
def button_request(self):
self.debuglink.press_yes()
def button_request(self, code):
if self.input_flow is None:
self.debuglink.press_yes()
elif self.input_flow is self.INPUT_FLOW_DONE:
raise AssertionError("input flow ended prematurely")
else:
try:
self.input_flow.send(code)
except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self):
def get_pin(self, code=None):
if self.pin:
return self.pin
else:
@ -154,12 +166,10 @@ class TrezorClientDebugLink(TrezorClient):
# of unit testing, because it will fail to work
# without special DebugLink interface provided
# by the device.
DEBUG = LOG.getChild("debug_link").debug
def __init__(self, transport):
self.debug = DebugLink(transport.find_debug())
self.ui = DebugUI(self.debug)
super().__init__(transport, self.ui)
self.in_with_statement = 0
self.button_wait = 0
@ -170,9 +180,11 @@ class TrezorClientDebugLink(TrezorClient):
# Do not expect any specific response from device
self.expected_responses = None
self.current_response = None
# Use blank passphrase
self.set_passphrase("")
super().__init__(transport, ui=self.ui)
def close(self):
super().close()
@ -182,6 +194,14 @@ class TrezorClientDebugLink(TrezorClient):
def set_buttonwait(self, secs):
self.button_wait = secs
def set_input_flow(self, input_flow):
if callable(input_flow):
input_flow = input_flow()
if not hasattr(input_flow, "send"):
raise RuntimeError("input_flow should be a generator function")
self.ui.input_flow = input_flow
next(input_flow) # can't send before first yield
def __enter__(self):
# For usage in with/expected_responses
self.in_with_statement += 1
@ -196,20 +216,19 @@ class TrezorClientDebugLink(TrezorClient):
# return isinstance(value, TypeError)
# Evaluate missed responses in 'with' statement
if self.expected_responses is not None and len(self.expected_responses):
raise RuntimeError(
"Some of expected responses didn't come from device: %s"
% [repr(x) for x in self.expected_responses]
)
if self.current_response < len(self.expected_responses):
self._raise_unexpected_response(None)
# Cleanup
self.expected_responses = None
self.current_response = None
return False
def set_expected_responses(self, expected):
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.expected_responses = expected
self.current_response = 0
def setup_debuglink(self, button, pin_correct):
# self.button = button # True -> YES button, False -> NO button
@ -224,7 +243,7 @@ class TrezorClientDebugLink(TrezorClient):
def set_mnemonic(self, mnemonic):
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def call_raw(self, msg):
def _raw_read(self):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# if SCREENSHOT and self.debug:
@ -241,36 +260,63 @@ class TrezorClientDebugLink(TrezorClient):
# im.save("scr%05d.png" % self.screenshot_id)
# self.screenshot_id += 1
resp = super().call_raw(msg)
resp = super()._raw_read()
self._check_request(resp)
return resp
def _check_request(self, msg):
def _raise_unexpected_response(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if self.expected_responses is not None:
try:
expected = self.expected_responses.pop(0)
except IndexError:
raise AssertionError(
proto.FailureType.UnexpectedMessage,
"Got %s, but no message has been expected" % repr(msg),
output = []
output.append("Expected responses:")
for i, exp in enumerate(self.expected_responses):
prefix = " " if i != self.current_response else ">>> "
set_fields = {
key: value
for key, value in exp.__dict__.items()
if value is not None and value != []
}
oneline_str = ", ".join("{}={!r}".format(*i) for i in set_fields.items())
if len(oneline_str) < 60:
output.append(
"{}{}({})".format(prefix, exp.__class__.__name__, oneline_str)
)
else:
output.append("{}{}(".format(prefix, exp.__class__.__name__))
for key, value in set_fields.items():
output.append("{} {}={!r}".format(prefix, key, value))
output.append("{})".format(prefix))
output.append("")
if msg is not None:
output.append("Actually received:")
output.append(format_message(msg))
else:
output.append("This message was never received.")
raise AssertionError("\n".join(output))
if msg.__class__ != expected.__class__:
raise AssertionError(
proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)),
)
def _check_request(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if self.expected_responses is None:
return
if self.current_response >= len(self.expected_responses):
raise AssertionError(
"No more messages were expected, but we got:\n" + format_message(msg)
)
expected = self.expected_responses[self.current_response]
if msg.__class__ != expected.__class__:
self._raise_unexpected_response(msg)
for field, value in expected.__dict__.items():
if value is None or value == []:
continue
if getattr(msg, field) != value:
self._raise_unexpected_response(msg)
for field, value in expected.__dict__.items():
if value is None or value == []:
continue
if getattr(msg, field) != value:
raise AssertionError(
proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)),
)
self.current_response += 1
def mnemonic_callback(self, _):
word, pos = self.debug.read_recovery_word()

@ -63,7 +63,7 @@ class TrezorTest:
label="test",
language="english",
)
if passphrase:
if conftest.TREZOR_VERSION > 1 and passphrase:
device.apply_settings(self.client, passphrase_source=PASSPHRASE_ON_HOST)
def setup_mnemonic_allallall(self):

@ -20,8 +20,9 @@ import os
import pytest
from trezorlib import coins, log
from trezorlib.client import TrezorClient, TrezorClientDebugLink
from trezorlib.debuglink import TrezorClientDebugLink
from trezorlib.transport import enumerate_devices, get_transport
from trezorlib import device, debuglink
TREZOR_VERSION = None
@ -42,7 +43,7 @@ def device_version():
device = get_device()
if not device:
raise RuntimeError()
client = TrezorClient(device)
client = TrezorClientDebugLink(device)
if client.features.model == "T":
return 2
else:
@ -52,11 +53,9 @@ def device_version():
@pytest.fixture(scope="function")
def client():
wirelink = get_device()
debuglink = wirelink.find_debug()
client = TrezorClientDebugLink(wirelink)
client.set_debuglink(debuglink)
client.set_tx_api(coins.tx_api["Bitcoin"])
client.wipe_device()
device.wipe(client)
client.transport.session_begin()
yield client
@ -78,7 +77,8 @@ def setup_client(mnemonic=None, pin="", passphrase=False):
def client_decorator(function):
@functools.wraps(function)
def wrapper(client, *args, **kwargs):
client.load_device_by_mnemonic(
debuglink.load_device_by_mnemonic(
client,
mnemonic=mnemonic,
pin=pin,
passphrase_protection=passphrase,

@ -188,6 +188,7 @@ class expect:
def __call__(self, f):
@functools.wraps(f)
def wrapped_f(*args, **kwargs):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
ret = f(*args, **kwargs)
if not isinstance(ret, self.expected):
raise RuntimeError(

Loading…
Cancel
Save