mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-02 19:01:04 +00:00
feat(python/debuglink): streamline expected responses handling [no changelog]
This commit is contained in:
parent
1012ee8497
commit
ea2a9375ac
@ -19,6 +19,7 @@ import textwrap
|
|||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
|
from itertools import zip_longest
|
||||||
|
|
||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
@ -379,7 +380,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
# Do not expect any specific response from device
|
# Do not expect any specific response from device
|
||||||
self.expected_responses = None
|
self.expected_responses = None
|
||||||
self.current_response = None
|
self.actual_responses = None
|
||||||
|
|
||||||
super().__init__(transport, ui=self.ui)
|
super().__init__(transport, ui=self.ui)
|
||||||
|
|
||||||
@ -464,31 +465,27 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
# For usage in with/expected_responses
|
# For usage in with/expected_responses
|
||||||
self.in_with_statement += 1
|
self.in_with_statement += 1
|
||||||
|
if self.in_with_statement > 1:
|
||||||
|
raise RuntimeError("Do not nest!")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, _type, value, traceback):
|
def __exit__(self, _type, value, traceback):
|
||||||
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
|
||||||
self.in_with_statement -= 1
|
self.in_with_statement -= 1
|
||||||
|
self.ui.clear()
|
||||||
|
self.watch_layout(False)
|
||||||
|
|
||||||
|
if _type is not None:
|
||||||
|
# Another exception raised
|
||||||
|
return False
|
||||||
|
|
||||||
# Clear input flow.
|
|
||||||
try:
|
try:
|
||||||
if _type is not None:
|
|
||||||
# Another exception raised
|
|
||||||
return False
|
|
||||||
|
|
||||||
if self.expected_responses is None:
|
|
||||||
# no need to check anything else
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Evaluate missed responses in 'with' statement
|
# Evaluate missed responses in 'with' statement
|
||||||
if self.current_response < len(self.expected_responses):
|
self._verify_responses(self.expected_responses, self.actual_responses)
|
||||||
self._raise_unexpected_response(None)
|
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Cleanup
|
|
||||||
self.expected_responses = None
|
self.expected_responses = None
|
||||||
self.current_response = None
|
self.actual_responses = None
|
||||||
self.ui.clear()
|
|
||||||
self.watch_layout(False)
|
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -528,8 +525,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
for valid, expected in expected_with_validity
|
for valid, expected in expected_with_validity
|
||||||
if valid
|
if valid
|
||||||
]
|
]
|
||||||
|
self.actual_responses = []
|
||||||
self.current_response = 0
|
|
||||||
|
|
||||||
def use_pin_sequence(self, pins):
|
def use_pin_sequence(self, pins):
|
||||||
"""Respond to PIN prompts from device with the provided PINs.
|
"""Respond to PIN prompts from device with the provided PINs.
|
||||||
@ -551,57 +547,57 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
resp = super()._raw_read()
|
resp = super()._raw_read()
|
||||||
resp = self._filter_message(resp)
|
resp = self._filter_message(resp)
|
||||||
self._check_request(resp)
|
if self.actual_responses is not None:
|
||||||
|
self.actual_responses.append(resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def _raw_write(self, msg):
|
def _raw_write(self, msg):
|
||||||
return super()._raw_write(self._filter_message(msg))
|
return super()._raw_write(self._filter_message(msg))
|
||||||
|
|
||||||
def _raise_unexpected_response(self, msg):
|
def _expectation_lines(self, expected, current):
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
||||||
|
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
|
||||||
start_at = max(self.current_response - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
|
||||||
stop_at = min(
|
|
||||||
self.current_response + EXPECTED_RESPONSES_CONTEXT_LINES + 1,
|
|
||||||
len(self.expected_responses),
|
|
||||||
)
|
|
||||||
output = []
|
output = []
|
||||||
output.append("Expected responses:")
|
output.append("Expected responses:")
|
||||||
if start_at > 0:
|
if start_at > 0:
|
||||||
output.append(" (...{} previous responses omitted)".format(start_at))
|
output.append(f" (...{start_at} previous responses omitted)")
|
||||||
for i in range(start_at, stop_at):
|
for i in range(start_at, stop_at):
|
||||||
exp = self.expected_responses[i]
|
exp = expected[i]
|
||||||
prefix = " " if i != self.current_response else ">>> "
|
prefix = " " if i != current else ">>> "
|
||||||
output.append(textwrap.indent(exp.format(), prefix))
|
output.append(textwrap.indent(exp.format(), prefix))
|
||||||
if stop_at < len(self.expected_responses):
|
if stop_at < len(expected):
|
||||||
omitted = len(self.expected_responses) - stop_at
|
omitted = len(expected) - stop_at
|
||||||
output.append(" (...{} following responses omitted)".format(omitted))
|
output.append(f" (...{omitted} following responses omitted)")
|
||||||
|
|
||||||
output.append("")
|
output.append("")
|
||||||
if msg is not None:
|
return output
|
||||||
output.append("Actually received:")
|
|
||||||
output.append(textwrap.indent(protobuf.format_message(msg), " "))
|
|
||||||
else:
|
|
||||||
output.append("This message was never received.")
|
|
||||||
raise AssertionError("\n".join(output))
|
|
||||||
|
|
||||||
def _check_request(self, msg):
|
def _verify_responses(self, expected, actual):
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
if self.expected_responses is None:
|
|
||||||
|
if expected is None and actual is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.current_response >= len(self.expected_responses):
|
for i, (exp, act) in enumerate(zip_longest(expected, actual)):
|
||||||
raise AssertionError(
|
if exp is None:
|
||||||
"No more messages were expected, but we got:\n"
|
output = self._expectation_lines(expected, i)
|
||||||
+ protobuf.format_message(msg)
|
output.append("No more messages were expected, but we got:")
|
||||||
)
|
for resp in actual[i:]:
|
||||||
|
output.append(
|
||||||
|
textwrap.indent(protobuf.format_message(resp), " ")
|
||||||
|
)
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
expected = self.expected_responses[self.current_response]
|
if act is None:
|
||||||
|
output = self._expectation_lines(expected, i)
|
||||||
|
output.append("This and the following message was not received.")
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
if not expected.match(msg):
|
if not exp.match(act):
|
||||||
self._raise_unexpected_response(msg)
|
output = self._expectation_lines(expected, i)
|
||||||
|
output.append("Actually received:")
|
||||||
self.current_response += 1
|
output.append(textwrap.indent(protobuf.format_message(act), " "))
|
||||||
|
raise AssertionError("\n".join(output))
|
||||||
|
|
||||||
def mnemonic_callback(self, _):
|
def mnemonic_callback(self, _):
|
||||||
word, pos = self.debug.read_recovery_word()
|
word, pos = self.debug.read_recovery_word()
|
||||||
|
Loading…
Reference in New Issue
Block a user