mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-08 22:40:59 +00:00
debug: improve infrastructure and expected message reporting
This commit is contained in:
parent
fc7a76e2f3
commit
c37bc9c38e
@ -27,7 +27,6 @@ from mnemonic import Mnemonic
|
|||||||
from . import (
|
from . import (
|
||||||
btc,
|
btc,
|
||||||
cosi,
|
cosi,
|
||||||
debuglink,
|
|
||||||
device,
|
device,
|
||||||
ethereum,
|
ethereum,
|
||||||
exceptions,
|
exceptions,
|
||||||
@ -96,12 +95,20 @@ class BaseClient(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self):
|
||||||
self.transport.write(proto.Cancel())
|
self._raw_write(proto.Cancel())
|
||||||
|
|
||||||
@tools.session
|
@tools.session
|
||||||
def call_raw(self, msg):
|
def call_raw(self, msg):
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
self._raw_write(msg)
|
||||||
|
return self._raw_read()
|
||||||
|
|
||||||
|
def _raw_write(self, msg):
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
self.transport.write(msg)
|
self.transport.write(msg)
|
||||||
|
|
||||||
|
def _raw_read(self):
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
return self.transport.read()
|
return self.transport.read()
|
||||||
|
|
||||||
def callback_PinMatrixRequest(self, msg):
|
def callback_PinMatrixRequest(self, msg):
|
||||||
@ -115,7 +122,7 @@ class BaseClient(object):
|
|||||||
proto.FailureType.PinCancelled,
|
proto.FailureType.PinCancelled,
|
||||||
proto.FailureType.PinExpected,
|
proto.FailureType.PinExpected,
|
||||||
):
|
):
|
||||||
raise exceptions.PinException(msg.code, msg.message)
|
raise exceptions.PinException(resp.code, resp.message)
|
||||||
else:
|
else:
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@ -131,10 +138,11 @@ class BaseClient(object):
|
|||||||
return self.call_raw(proto.PassphraseStateAck())
|
return self.call_raw(proto.PassphraseStateAck())
|
||||||
|
|
||||||
def callback_ButtonRequest(self, msg):
|
def callback_ButtonRequest(self, msg):
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
# do this raw - send ButtonAck first, notify UI later
|
# do this raw - send ButtonAck first, notify UI later
|
||||||
self.transport.write(proto.ButtonAck())
|
self._raw_write(proto.ButtonAck())
|
||||||
self.ui.button_request(msg.code)
|
self.ui.button_request(msg.code)
|
||||||
return self.transport.read()
|
return self._raw_read()
|
||||||
|
|
||||||
@tools.session
|
@tools.session
|
||||||
def call(self, msg):
|
def call(self, msg):
|
||||||
|
@ -20,6 +20,7 @@ from mnemonic import Mnemonic
|
|||||||
from . import messages as proto, tools
|
from . import messages as proto, tools
|
||||||
from .client import TrezorClient
|
from .client import TrezorClient
|
||||||
from .tools import expect
|
from .tools import expect
|
||||||
|
from .protobuf import format_message
|
||||||
|
|
||||||
|
|
||||||
class DebugLink:
|
class DebugLink:
|
||||||
@ -126,15 +127,26 @@ class DebugLink:
|
|||||||
|
|
||||||
|
|
||||||
class DebugUI:
|
class DebugUI:
|
||||||
|
INPUT_FLOW_DONE = object()
|
||||||
|
|
||||||
def __init__(self, debuglink: DebugLink):
|
def __init__(self, debuglink: DebugLink):
|
||||||
self.debuglink = debuglink
|
self.debuglink = debuglink
|
||||||
self.pin = None
|
self.pin = None
|
||||||
self.passphrase = "sphinx of black quartz, judge my wov"
|
self.passphrase = "sphinx of black quartz, judge my wov"
|
||||||
|
self.input_flow = None
|
||||||
|
|
||||||
def button_request(self):
|
def button_request(self, code):
|
||||||
|
if self.input_flow is None:
|
||||||
self.debuglink.press_yes()
|
self.debuglink.press_yes()
|
||||||
|
elif self.input_flow is self.INPUT_FLOW_DONE:
|
||||||
|
raise AssertionError("input flow ended prematurely")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.input_flow.send(code)
|
||||||
|
except StopIteration:
|
||||||
|
self.input_flow = self.INPUT_FLOW_DONE
|
||||||
|
|
||||||
def get_pin(self):
|
def get_pin(self, code=None):
|
||||||
if self.pin:
|
if self.pin:
|
||||||
return self.pin
|
return self.pin
|
||||||
else:
|
else:
|
||||||
@ -154,12 +166,10 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# of unit testing, because it will fail to work
|
# of unit testing, because it will fail to work
|
||||||
# without special DebugLink interface provided
|
# without special DebugLink interface provided
|
||||||
# by the device.
|
# by the device.
|
||||||
DEBUG = LOG.getChild("debug_link").debug
|
|
||||||
|
|
||||||
def __init__(self, transport):
|
def __init__(self, transport):
|
||||||
self.debug = DebugLink(transport.find_debug())
|
self.debug = DebugLink(transport.find_debug())
|
||||||
self.ui = DebugUI(self.debug)
|
self.ui = DebugUI(self.debug)
|
||||||
super().__init__(transport, self.ui)
|
|
||||||
|
|
||||||
self.in_with_statement = 0
|
self.in_with_statement = 0
|
||||||
self.button_wait = 0
|
self.button_wait = 0
|
||||||
@ -170,9 +180,11 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
# Do not expect any specific response from device
|
# Do not expect any specific response from device
|
||||||
self.expected_responses = None
|
self.expected_responses = None
|
||||||
|
self.current_response = None
|
||||||
|
|
||||||
# Use blank passphrase
|
# Use blank passphrase
|
||||||
self.set_passphrase("")
|
self.set_passphrase("")
|
||||||
|
super().__init__(transport, ui=self.ui)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
super().close()
|
super().close()
|
||||||
@ -182,6 +194,14 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
def set_buttonwait(self, secs):
|
def set_buttonwait(self, secs):
|
||||||
self.button_wait = secs
|
self.button_wait = secs
|
||||||
|
|
||||||
|
def set_input_flow(self, input_flow):
|
||||||
|
if callable(input_flow):
|
||||||
|
input_flow = input_flow()
|
||||||
|
if not hasattr(input_flow, "send"):
|
||||||
|
raise RuntimeError("input_flow should be a generator function")
|
||||||
|
self.ui.input_flow = input_flow
|
||||||
|
next(input_flow) # can't send before first yield
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
# For usage in with/expected_responses
|
# For usage in with/expected_responses
|
||||||
self.in_with_statement += 1
|
self.in_with_statement += 1
|
||||||
@ -196,20 +216,19 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
# return isinstance(value, TypeError)
|
# return isinstance(value, TypeError)
|
||||||
# Evaluate missed responses in 'with' statement
|
# Evaluate missed responses in 'with' statement
|
||||||
if self.expected_responses is not None and len(self.expected_responses):
|
if self.current_response < len(self.expected_responses):
|
||||||
raise RuntimeError(
|
self._raise_unexpected_response(None)
|
||||||
"Some of expected responses didn't come from device: %s"
|
|
||||||
% [repr(x) for x in self.expected_responses]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
self.expected_responses = None
|
self.expected_responses = None
|
||||||
|
self.current_response = None
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def set_expected_responses(self, expected):
|
def set_expected_responses(self, expected):
|
||||||
if not self.in_with_statement:
|
if not self.in_with_statement:
|
||||||
raise RuntimeError("Must be called inside 'with' statement")
|
raise RuntimeError("Must be called inside 'with' statement")
|
||||||
self.expected_responses = expected
|
self.expected_responses = expected
|
||||||
|
self.current_response = 0
|
||||||
|
|
||||||
def setup_debuglink(self, button, pin_correct):
|
def setup_debuglink(self, button, pin_correct):
|
||||||
# self.button = button # True -> YES button, False -> NO button
|
# self.button = button # True -> YES button, False -> NO button
|
||||||
@ -224,7 +243,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
def set_mnemonic(self, mnemonic):
|
def set_mnemonic(self, mnemonic):
|
||||||
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
||||||
|
|
||||||
def call_raw(self, msg):
|
def _raw_read(self):
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
# if SCREENSHOT and self.debug:
|
# if SCREENSHOT and self.debug:
|
||||||
@ -241,36 +260,63 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# im.save("scr%05d.png" % self.screenshot_id)
|
# im.save("scr%05d.png" % self.screenshot_id)
|
||||||
# self.screenshot_id += 1
|
# self.screenshot_id += 1
|
||||||
|
|
||||||
resp = super().call_raw(msg)
|
resp = super()._raw_read()
|
||||||
self._check_request(resp)
|
self._check_request(resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def _check_request(self, msg):
|
def _raise_unexpected_response(self, msg):
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
if self.expected_responses is not None:
|
output = []
|
||||||
try:
|
output.append("Expected responses:")
|
||||||
expected = self.expected_responses.pop(0)
|
for i, exp in enumerate(self.expected_responses):
|
||||||
except IndexError:
|
prefix = " " if i != self.current_response else ">>> "
|
||||||
|
set_fields = {
|
||||||
|
key: value
|
||||||
|
for key, value in exp.__dict__.items()
|
||||||
|
if value is not None and value != []
|
||||||
|
}
|
||||||
|
oneline_str = ", ".join("{}={!r}".format(*i) for i in set_fields.items())
|
||||||
|
if len(oneline_str) < 60:
|
||||||
|
output.append(
|
||||||
|
"{}{}({})".format(prefix, exp.__class__.__name__, oneline_str)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output.append("{}{}(".format(prefix, exp.__class__.__name__))
|
||||||
|
for key, value in set_fields.items():
|
||||||
|
output.append("{} {}={!r}".format(prefix, key, value))
|
||||||
|
output.append("{})".format(prefix))
|
||||||
|
|
||||||
|
output.append("")
|
||||||
|
if msg is not None:
|
||||||
|
output.append("Actually received:")
|
||||||
|
output.append(format_message(msg))
|
||||||
|
else:
|
||||||
|
output.append("This message was never received.")
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
|
def _check_request(self, msg):
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
if self.expected_responses is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.current_response >= len(self.expected_responses):
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
proto.FailureType.UnexpectedMessage,
|
"No more messages were expected, but we got:\n" + format_message(msg)
|
||||||
"Got %s, but no message has been expected" % repr(msg),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
expected = self.expected_responses[self.current_response]
|
||||||
|
|
||||||
if msg.__class__ != expected.__class__:
|
if msg.__class__ != expected.__class__:
|
||||||
raise AssertionError(
|
self._raise_unexpected_response(msg)
|
||||||
proto.FailureType.UnexpectedMessage,
|
|
||||||
"Expected %s, got %s" % (repr(expected), repr(msg)),
|
|
||||||
)
|
|
||||||
|
|
||||||
for field, value in expected.__dict__.items():
|
for field, value in expected.__dict__.items():
|
||||||
if value is None or value == []:
|
if value is None or value == []:
|
||||||
continue
|
continue
|
||||||
if getattr(msg, field) != value:
|
if getattr(msg, field) != value:
|
||||||
raise AssertionError(
|
self._raise_unexpected_response(msg)
|
||||||
proto.FailureType.UnexpectedMessage,
|
|
||||||
"Expected %s, got %s" % (repr(expected), repr(msg)),
|
self.current_response += 1
|
||||||
)
|
|
||||||
|
|
||||||
def mnemonic_callback(self, _):
|
def mnemonic_callback(self, _):
|
||||||
word, pos = self.debug.read_recovery_word()
|
word, pos = self.debug.read_recovery_word()
|
||||||
|
@ -63,7 +63,7 @@ class TrezorTest:
|
|||||||
label="test",
|
label="test",
|
||||||
language="english",
|
language="english",
|
||||||
)
|
)
|
||||||
if passphrase:
|
if conftest.TREZOR_VERSION > 1 and passphrase:
|
||||||
device.apply_settings(self.client, passphrase_source=PASSPHRASE_ON_HOST)
|
device.apply_settings(self.client, passphrase_source=PASSPHRASE_ON_HOST)
|
||||||
|
|
||||||
def setup_mnemonic_allallall(self):
|
def setup_mnemonic_allallall(self):
|
||||||
|
@ -20,8 +20,9 @@ import os
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from trezorlib import coins, log
|
from trezorlib import coins, log
|
||||||
from trezorlib.client import TrezorClient, TrezorClientDebugLink
|
from trezorlib.debuglink import TrezorClientDebugLink
|
||||||
from trezorlib.transport import enumerate_devices, get_transport
|
from trezorlib.transport import enumerate_devices, get_transport
|
||||||
|
from trezorlib import device, debuglink
|
||||||
|
|
||||||
TREZOR_VERSION = None
|
TREZOR_VERSION = None
|
||||||
|
|
||||||
@ -42,7 +43,7 @@ def device_version():
|
|||||||
device = get_device()
|
device = get_device()
|
||||||
if not device:
|
if not device:
|
||||||
raise RuntimeError()
|
raise RuntimeError()
|
||||||
client = TrezorClient(device)
|
client = TrezorClientDebugLink(device)
|
||||||
if client.features.model == "T":
|
if client.features.model == "T":
|
||||||
return 2
|
return 2
|
||||||
else:
|
else:
|
||||||
@ -52,11 +53,9 @@ def device_version():
|
|||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def client():
|
def client():
|
||||||
wirelink = get_device()
|
wirelink = get_device()
|
||||||
debuglink = wirelink.find_debug()
|
|
||||||
client = TrezorClientDebugLink(wirelink)
|
client = TrezorClientDebugLink(wirelink)
|
||||||
client.set_debuglink(debuglink)
|
|
||||||
client.set_tx_api(coins.tx_api["Bitcoin"])
|
client.set_tx_api(coins.tx_api["Bitcoin"])
|
||||||
client.wipe_device()
|
device.wipe(client)
|
||||||
client.transport.session_begin()
|
client.transport.session_begin()
|
||||||
|
|
||||||
yield client
|
yield client
|
||||||
@ -78,7 +77,8 @@ def setup_client(mnemonic=None, pin="", passphrase=False):
|
|||||||
def client_decorator(function):
|
def client_decorator(function):
|
||||||
@functools.wraps(function)
|
@functools.wraps(function)
|
||||||
def wrapper(client, *args, **kwargs):
|
def wrapper(client, *args, **kwargs):
|
||||||
client.load_device_by_mnemonic(
|
debuglink.load_device_by_mnemonic(
|
||||||
|
client,
|
||||||
mnemonic=mnemonic,
|
mnemonic=mnemonic,
|
||||||
pin=pin,
|
pin=pin,
|
||||||
passphrase_protection=passphrase,
|
passphrase_protection=passphrase,
|
||||||
|
@ -188,6 +188,7 @@ class expect:
|
|||||||
def __call__(self, f):
|
def __call__(self, f):
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapped_f(*args, **kwargs):
|
def wrapped_f(*args, **kwargs):
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
ret = f(*args, **kwargs)
|
ret = f(*args, **kwargs)
|
||||||
if not isinstance(ret, self.expected):
|
if not isinstance(ret, self.expected):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
Loading…
Reference in New Issue
Block a user