From ea2a9375ac84fe465111f747233c684fc5447386 Mon Sep 17 00:00:00 2001 From: matejcik Date: Thu, 17 Jun 2021 16:27:14 +0200 Subject: [PATCH] feat(python/debuglink): streamline expected responses handling [no changelog] --- python/src/trezorlib/debuglink.py | 100 ++++++++++++++---------------- 1 file changed, 48 insertions(+), 52 deletions(-) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 75947a8902..4783600b7e 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -19,6 +19,7 @@ import textwrap from collections import namedtuple from copy import deepcopy from enum import IntEnum +from itertools import zip_longest from mnemonic import Mnemonic @@ -379,7 +380,7 @@ class TrezorClientDebugLink(TrezorClient): # Do not expect any specific response from device self.expected_responses = None - self.current_response = None + self.actual_responses = None super().__init__(transport, ui=self.ui) @@ -464,31 +465,27 @@ class TrezorClientDebugLink(TrezorClient): def __enter__(self): # For usage in with/expected_responses self.in_with_statement += 1 + if self.in_with_statement > 1: + raise RuntimeError("Do not nest!") return self def __exit__(self, _type, value, traceback): + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + self.in_with_statement -= 1 + self.ui.clear() + self.watch_layout(False) + + if _type is not None: + # Another exception raised + return False - # Clear input flow. try: - if _type is not None: - # Another exception raised - return False - - if self.expected_responses is None: - # no need to check anything else - return False - # Evaluate missed responses in 'with' statement - if self.current_response < len(self.expected_responses): - self._raise_unexpected_response(None) - + self._verify_responses(self.expected_responses, self.actual_responses) finally: - # Cleanup self.expected_responses = None - self.current_response = None - self.ui.clear() - self.watch_layout(False) + self.actual_responses = None return False @@ -528,8 +525,7 @@ class TrezorClientDebugLink(TrezorClient): for valid, expected in expected_with_validity if valid ] - - self.current_response = 0 + self.actual_responses = [] def use_pin_sequence(self, pins): """Respond to PIN prompts from device with the provided PINs. @@ -551,57 +547,57 @@ class TrezorClientDebugLink(TrezorClient): resp = super()._raw_read() resp = self._filter_message(resp) - self._check_request(resp) + if self.actual_responses is not None: + self.actual_responses.append(resp) return resp def _raw_write(self, msg): return super()._raw_write(self._filter_message(msg)) - def _raise_unexpected_response(self, msg): - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - start_at = max(self.current_response - EXPECTED_RESPONSES_CONTEXT_LINES, 0) - stop_at = min( - self.current_response + EXPECTED_RESPONSES_CONTEXT_LINES + 1, - len(self.expected_responses), - ) + def _expectation_lines(self, expected, current): + start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) + stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) output = [] output.append("Expected responses:") if start_at > 0: - output.append(" (...{} previous responses omitted)".format(start_at)) + output.append(f" (...{start_at} previous responses omitted)") for i in range(start_at, stop_at): - exp = self.expected_responses[i] - prefix = " " if i != self.current_response else ">>> " + exp = expected[i] + prefix = " " if i != current else ">>> " output.append(textwrap.indent(exp.format(), prefix)) - if stop_at < len(self.expected_responses): - omitted = len(self.expected_responses) - stop_at - output.append(" (...{} following responses omitted)".format(omitted)) + if stop_at < len(expected): + omitted = len(expected) - stop_at + output.append(f" (...{omitted} following responses omitted)") output.append("") - if msg is not None: - output.append("Actually received:") - output.append(textwrap.indent(protobuf.format_message(msg), " ")) - else: - output.append("This message was never received.") - raise AssertionError("\n".join(output)) + return output - def _check_request(self, msg): + def _verify_responses(self, expected, actual): __tracebackhide__ = True # for pytest # pylint: disable=W0612 - if self.expected_responses is None: + + if expected is None and actual is None: return - if self.current_response >= len(self.expected_responses): - raise AssertionError( - "No more messages were expected, but we got:\n" - + protobuf.format_message(msg) - ) + for i, (exp, act) in enumerate(zip_longest(expected, actual)): + if exp is None: + output = self._expectation_lines(expected, i) + output.append("No more messages were expected, but we got:") + for resp in actual[i:]: + output.append( + textwrap.indent(protobuf.format_message(resp), " ") + ) + raise AssertionError("\n".join(output)) - expected = self.expected_responses[self.current_response] + if act is None: + output = self._expectation_lines(expected, i) + output.append("This and the following message was not received.") + raise AssertionError("\n".join(output)) - if not expected.match(msg): - self._raise_unexpected_response(msg) - - self.current_response += 1 + if not exp.match(act): + output = self._expectation_lines(expected, i) + output.append("Actually received:") + output.append(textwrap.indent(protobuf.format_message(act), " ")) + raise AssertionError("\n".join(output)) def mnemonic_callback(self, _): word, pos = self.debug.read_recovery_word()