mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 14:28:07 +00:00
debug: improve infrastructure and expected message reporting
This commit is contained in:
parent
fc7a76e2f3
commit
c37bc9c38e
@ -27,7 +27,6 @@ from mnemonic import Mnemonic
|
||||
from . import (
|
||||
btc,
|
||||
cosi,
|
||||
debuglink,
|
||||
device,
|
||||
ethereum,
|
||||
exceptions,
|
||||
@ -96,12 +95,20 @@ class BaseClient(object):
|
||||
pass
|
||||
|
||||
def cancel(self):
|
||||
self.transport.write(proto.Cancel())
|
||||
self._raw_write(proto.Cancel())
|
||||
|
||||
@tools.session
|
||||
def call_raw(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
self._raw_write(msg)
|
||||
return self._raw_read()
|
||||
|
||||
def _raw_write(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
self.transport.write(msg)
|
||||
|
||||
def _raw_read(self):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
return self.transport.read()
|
||||
|
||||
def callback_PinMatrixRequest(self, msg):
|
||||
@ -115,7 +122,7 @@ class BaseClient(object):
|
||||
proto.FailureType.PinCancelled,
|
||||
proto.FailureType.PinExpected,
|
||||
):
|
||||
raise exceptions.PinException(msg.code, msg.message)
|
||||
raise exceptions.PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
|
||||
@ -131,10 +138,11 @@ class BaseClient(object):
|
||||
return self.call_raw(proto.PassphraseStateAck())
|
||||
|
||||
def callback_ButtonRequest(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
# do this raw - send ButtonAck first, notify UI later
|
||||
self.transport.write(proto.ButtonAck())
|
||||
self._raw_write(proto.ButtonAck())
|
||||
self.ui.button_request(msg.code)
|
||||
return self.transport.read()
|
||||
return self._raw_read()
|
||||
|
||||
@tools.session
|
||||
def call(self, msg):
|
||||
|
@ -20,6 +20,7 @@ from mnemonic import Mnemonic
|
||||
from . import messages as proto, tools
|
||||
from .client import TrezorClient
|
||||
from .tools import expect
|
||||
from .protobuf import format_message
|
||||
|
||||
|
||||
class DebugLink:
|
||||
@ -126,15 +127,26 @@ class DebugLink:
|
||||
|
||||
|
||||
class DebugUI:
|
||||
INPUT_FLOW_DONE = object()
|
||||
|
||||
def __init__(self, debuglink: DebugLink):
|
||||
self.debuglink = debuglink
|
||||
self.pin = None
|
||||
self.passphrase = "sphinx of black quartz, judge my wov"
|
||||
self.input_flow = None
|
||||
|
||||
def button_request(self):
|
||||
self.debuglink.press_yes()
|
||||
def button_request(self, code):
|
||||
if self.input_flow is None:
|
||||
self.debuglink.press_yes()
|
||||
elif self.input_flow is self.INPUT_FLOW_DONE:
|
||||
raise AssertionError("input flow ended prematurely")
|
||||
else:
|
||||
try:
|
||||
self.input_flow.send(code)
|
||||
except StopIteration:
|
||||
self.input_flow = self.INPUT_FLOW_DONE
|
||||
|
||||
def get_pin(self):
|
||||
def get_pin(self, code=None):
|
||||
if self.pin:
|
||||
return self.pin
|
||||
else:
|
||||
@ -154,12 +166,10 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
# of unit testing, because it will fail to work
|
||||
# without special DebugLink interface provided
|
||||
# by the device.
|
||||
DEBUG = LOG.getChild("debug_link").debug
|
||||
|
||||
def __init__(self, transport):
|
||||
self.debug = DebugLink(transport.find_debug())
|
||||
self.ui = DebugUI(self.debug)
|
||||
super().__init__(transport, self.ui)
|
||||
|
||||
self.in_with_statement = 0
|
||||
self.button_wait = 0
|
||||
@ -170,9 +180,11 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
# Do not expect any specific response from device
|
||||
self.expected_responses = None
|
||||
self.current_response = None
|
||||
|
||||
# Use blank passphrase
|
||||
self.set_passphrase("")
|
||||
super().__init__(transport, ui=self.ui)
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
@ -182,6 +194,14 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
def set_buttonwait(self, secs):
|
||||
self.button_wait = secs
|
||||
|
||||
def set_input_flow(self, input_flow):
|
||||
if callable(input_flow):
|
||||
input_flow = input_flow()
|
||||
if not hasattr(input_flow, "send"):
|
||||
raise RuntimeError("input_flow should be a generator function")
|
||||
self.ui.input_flow = input_flow
|
||||
next(input_flow) # can't send before first yield
|
||||
|
||||
def __enter__(self):
|
||||
# For usage in with/expected_responses
|
||||
self.in_with_statement += 1
|
||||
@ -196,20 +216,19 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
# return isinstance(value, TypeError)
|
||||
# Evaluate missed responses in 'with' statement
|
||||
if self.expected_responses is not None and len(self.expected_responses):
|
||||
raise RuntimeError(
|
||||
"Some of expected responses didn't come from device: %s"
|
||||
% [repr(x) for x in self.expected_responses]
|
||||
)
|
||||
if self.current_response < len(self.expected_responses):
|
||||
self._raise_unexpected_response(None)
|
||||
|
||||
# Cleanup
|
||||
self.expected_responses = None
|
||||
self.current_response = None
|
||||
return False
|
||||
|
||||
def set_expected_responses(self, expected):
|
||||
if not self.in_with_statement:
|
||||
raise RuntimeError("Must be called inside 'with' statement")
|
||||
self.expected_responses = expected
|
||||
self.current_response = 0
|
||||
|
||||
def setup_debuglink(self, button, pin_correct):
|
||||
# self.button = button # True -> YES button, False -> NO button
|
||||
@ -224,7 +243,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
def set_mnemonic(self, mnemonic):
|
||||
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
||||
|
||||
def call_raw(self, msg):
|
||||
def _raw_read(self):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
# if SCREENSHOT and self.debug:
|
||||
@ -241,36 +260,63 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
# im.save("scr%05d.png" % self.screenshot_id)
|
||||
# self.screenshot_id += 1
|
||||
|
||||
resp = super().call_raw(msg)
|
||||
resp = super()._raw_read()
|
||||
self._check_request(resp)
|
||||
return resp
|
||||
|
||||
def _check_request(self, msg):
|
||||
def _raise_unexpected_response(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
if self.expected_responses is not None:
|
||||
try:
|
||||
expected = self.expected_responses.pop(0)
|
||||
except IndexError:
|
||||
raise AssertionError(
|
||||
proto.FailureType.UnexpectedMessage,
|
||||
"Got %s, but no message has been expected" % repr(msg),
|
||||
output = []
|
||||
output.append("Expected responses:")
|
||||
for i, exp in enumerate(self.expected_responses):
|
||||
prefix = " " if i != self.current_response else ">>> "
|
||||
set_fields = {
|
||||
key: value
|
||||
for key, value in exp.__dict__.items()
|
||||
if value is not None and value != []
|
||||
}
|
||||
oneline_str = ", ".join("{}={!r}".format(*i) for i in set_fields.items())
|
||||
if len(oneline_str) < 60:
|
||||
output.append(
|
||||
"{}{}({})".format(prefix, exp.__class__.__name__, oneline_str)
|
||||
)
|
||||
else:
|
||||
output.append("{}{}(".format(prefix, exp.__class__.__name__))
|
||||
for key, value in set_fields.items():
|
||||
output.append("{} {}={!r}".format(prefix, key, value))
|
||||
output.append("{})".format(prefix))
|
||||
|
||||
if msg.__class__ != expected.__class__:
|
||||
raise AssertionError(
|
||||
proto.FailureType.UnexpectedMessage,
|
||||
"Expected %s, got %s" % (repr(expected), repr(msg)),
|
||||
)
|
||||
output.append("")
|
||||
if msg is not None:
|
||||
output.append("Actually received:")
|
||||
output.append(format_message(msg))
|
||||
else:
|
||||
output.append("This message was never received.")
|
||||
raise AssertionError("\n".join(output))
|
||||
|
||||
for field, value in expected.__dict__.items():
|
||||
if value is None or value == []:
|
||||
continue
|
||||
if getattr(msg, field) != value:
|
||||
raise AssertionError(
|
||||
proto.FailureType.UnexpectedMessage,
|
||||
"Expected %s, got %s" % (repr(expected), repr(msg)),
|
||||
)
|
||||
def _check_request(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
if self.expected_responses is None:
|
||||
return
|
||||
|
||||
if self.current_response >= len(self.expected_responses):
|
||||
raise AssertionError(
|
||||
"No more messages were expected, but we got:\n" + format_message(msg)
|
||||
)
|
||||
|
||||
expected = self.expected_responses[self.current_response]
|
||||
|
||||
if msg.__class__ != expected.__class__:
|
||||
self._raise_unexpected_response(msg)
|
||||
|
||||
for field, value in expected.__dict__.items():
|
||||
if value is None or value == []:
|
||||
continue
|
||||
if getattr(msg, field) != value:
|
||||
self._raise_unexpected_response(msg)
|
||||
|
||||
self.current_response += 1
|
||||
|
||||
def mnemonic_callback(self, _):
|
||||
word, pos = self.debug.read_recovery_word()
|
||||
|
@ -63,7 +63,7 @@ class TrezorTest:
|
||||
label="test",
|
||||
language="english",
|
||||
)
|
||||
if passphrase:
|
||||
if conftest.TREZOR_VERSION > 1 and passphrase:
|
||||
device.apply_settings(self.client, passphrase_source=PASSPHRASE_ON_HOST)
|
||||
|
||||
def setup_mnemonic_allallall(self):
|
||||
|
@ -20,8 +20,9 @@ import os
|
||||
import pytest
|
||||
|
||||
from trezorlib import coins, log
|
||||
from trezorlib.client import TrezorClient, TrezorClientDebugLink
|
||||
from trezorlib.debuglink import TrezorClientDebugLink
|
||||
from trezorlib.transport import enumerate_devices, get_transport
|
||||
from trezorlib import device, debuglink
|
||||
|
||||
TREZOR_VERSION = None
|
||||
|
||||
@ -42,7 +43,7 @@ def device_version():
|
||||
device = get_device()
|
||||
if not device:
|
||||
raise RuntimeError()
|
||||
client = TrezorClient(device)
|
||||
client = TrezorClientDebugLink(device)
|
||||
if client.features.model == "T":
|
||||
return 2
|
||||
else:
|
||||
@ -52,11 +53,9 @@ def device_version():
|
||||
@pytest.fixture(scope="function")
|
||||
def client():
|
||||
wirelink = get_device()
|
||||
debuglink = wirelink.find_debug()
|
||||
client = TrezorClientDebugLink(wirelink)
|
||||
client.set_debuglink(debuglink)
|
||||
client.set_tx_api(coins.tx_api["Bitcoin"])
|
||||
client.wipe_device()
|
||||
device.wipe(client)
|
||||
client.transport.session_begin()
|
||||
|
||||
yield client
|
||||
@ -78,7 +77,8 @@ def setup_client(mnemonic=None, pin="", passphrase=False):
|
||||
def client_decorator(function):
|
||||
@functools.wraps(function)
|
||||
def wrapper(client, *args, **kwargs):
|
||||
client.load_device_by_mnemonic(
|
||||
debuglink.load_device_by_mnemonic(
|
||||
client,
|
||||
mnemonic=mnemonic,
|
||||
pin=pin,
|
||||
passphrase_protection=passphrase,
|
||||
|
@ -188,6 +188,7 @@ class expect:
|
||||
def __call__(self, f):
|
||||
@functools.wraps(f)
|
||||
def wrapped_f(*args, **kwargs):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
ret = f(*args, **kwargs)
|
||||
if not isinstance(ret, self.expected):
|
||||
raise RuntimeError(
|
||||
|
Loading…
Reference in New Issue
Block a user