1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-23 13:51:00 +00:00

feat(python/debuglink): streamline expected responses handling [no changelog]

This commit is contained in:
matejcik 2021-06-17 16:27:14 +02:00 committed by matejcik
parent 1012ee8497
commit ea2a9375ac

View File

@ -19,6 +19,7 @@ import textwrap
from collections import namedtuple
from copy import deepcopy
from enum import IntEnum
from itertools import zip_longest
from mnemonic import Mnemonic
@ -379,7 +380,7 @@ class TrezorClientDebugLink(TrezorClient):
# Do not expect any specific response from device
self.expected_responses = None
self.current_response = None
self.actual_responses = None
super().__init__(transport, ui=self.ui)
@ -464,31 +465,27 @@ class TrezorClientDebugLink(TrezorClient):
def __enter__(self):
# For usage in with/expected_responses
self.in_with_statement += 1
if self.in_with_statement > 1:
raise RuntimeError("Do not nest!")
return self
def __exit__(self, _type, value, traceback):
self.in_with_statement -= 1
__tracebackhide__ = True # for pytest # pylint: disable=W0612
self.in_with_statement -= 1
self.ui.clear()
self.watch_layout(False)
# Clear input flow.
try:
if _type is not None:
# Another exception raised
return False
if self.expected_responses is None:
# no need to check anything else
return False
try:
# Evaluate missed responses in 'with' statement
if self.current_response < len(self.expected_responses):
self._raise_unexpected_response(None)
self._verify_responses(self.expected_responses, self.actual_responses)
finally:
# Cleanup
self.expected_responses = None
self.current_response = None
self.ui.clear()
self.watch_layout(False)
self.actual_responses = None
return False
@ -528,8 +525,7 @@ class TrezorClientDebugLink(TrezorClient):
for valid, expected in expected_with_validity
if valid
]
self.current_response = 0
self.actual_responses = []
def use_pin_sequence(self, pins):
"""Respond to PIN prompts from device with the provided PINs.
@ -551,57 +547,57 @@ class TrezorClientDebugLink(TrezorClient):
resp = super()._raw_read()
resp = self._filter_message(resp)
self._check_request(resp)
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
def _raw_write(self, msg):
return super()._raw_write(self._filter_message(msg))
def _raise_unexpected_response(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
start_at = max(self.current_response - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(
self.current_response + EXPECTED_RESPONSES_CONTEXT_LINES + 1,
len(self.expected_responses),
)
def _expectation_lines(self, expected, current):
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output = []
output.append("Expected responses:")
if start_at > 0:
output.append(" (...{} previous responses omitted)".format(start_at))
output.append(f" (...{start_at} previous responses omitted)")
for i in range(start_at, stop_at):
exp = self.expected_responses[i]
prefix = " " if i != self.current_response else ">>> "
exp = expected[i]
prefix = " " if i != current else ">>> "
output.append(textwrap.indent(exp.format(), prefix))
if stop_at < len(self.expected_responses):
omitted = len(self.expected_responses) - stop_at
output.append(" (...{} following responses omitted)".format(omitted))
if stop_at < len(expected):
omitted = len(expected) - stop_at
output.append(f" (...{omitted} following responses omitted)")
output.append("")
if msg is not None:
output.append("Actually received:")
output.append(textwrap.indent(protobuf.format_message(msg), " "))
else:
output.append("This message was never received.")
raise AssertionError("\n".join(output))
return output
def _check_request(self, msg):
def _verify_responses(self, expected, actual):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if self.expected_responses is None:
if expected is None and actual is None:
return
if self.current_response >= len(self.expected_responses):
raise AssertionError(
"No more messages were expected, but we got:\n"
+ protobuf.format_message(msg)
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
if exp is None:
output = self._expectation_lines(expected, i)
output.append("No more messages were expected, but we got:")
for resp in actual[i:]:
output.append(
textwrap.indent(protobuf.format_message(resp), " ")
)
raise AssertionError("\n".join(output))
expected = self.expected_responses[self.current_response]
if act is None:
output = self._expectation_lines(expected, i)
output.append("This and the following message was not received.")
raise AssertionError("\n".join(output))
if not expected.match(msg):
self._raise_unexpected_response(msg)
self.current_response += 1
if not exp.match(act):
output = self._expectation_lines(expected, i)
output.append("Actually received:")
output.append(textwrap.indent(protobuf.format_message(act), " "))
raise AssertionError("\n".join(output))
def mnemonic_callback(self, _):
word, pos = self.debug.read_recovery_word()