diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 889475dc6..37ce1e679 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -374,19 +374,21 @@ class TrezorClientDebugLink(TrezorClient): else: raise - self.ui = DebugUI(self.debug) - - self.in_with_statement = 0 - self.screenshot_id = 0 - - self.filters = {} - - # Do not expect any specific response from device - self.expected_responses = None - self.actual_responses = None + self.reset_debug_features() super().__init__(transport, ui=self.ui) + def reset_debug_features(self): + """Prepare the debugging client for a new testcase. + + Clears all debugging state that might have been modified by a testcase. + """ + self.ui = DebugUI(self.debug) + self.in_with_statement = False + self.expected_responses = None + self.actual_responses = None + self.filters = {} + def open(self): super().open() if self.session_counter == 1: @@ -470,31 +472,24 @@ class TrezorClientDebugLink(TrezorClient): def __enter__(self): # For usage in with/expected_responses - self.in_with_statement += 1 - if self.in_with_statement > 1: + if self.in_with_statement: raise RuntimeError("Do not nest!") + self.in_with_statement = True return self - def __exit__(self, _type, value, traceback): + def __exit__(self, exc_type, value, traceback): __tracebackhide__ = True # for pytest # pylint: disable=W0612 - self.in_with_statement -= 1 - self.ui.clear() - self.filters.clear() self.watch_layout(False) + # copy expected/actual responses before clearing them + expected_responses = self.expected_responses + actual_responses = self.actual_responses + self.reset_debug_features() - if _type is not None: - # Another exception raised - return False - - try: - # Evaluate missed responses in 'with' statement - self._verify_responses(self.expected_responses, self.actual_responses) - finally: - self.expected_responses = None - self.actual_responses = None - - return False + if exc_type is None: + # If no other exception was raised, evaluate missed responses + # (raises AssertionError on mismatch) + self._verify_responses(expected_responses, actual_responses) def set_expected_responses(self, expected): """Set a sequence of expected responses to client calls. @@ -561,7 +556,8 @@ class TrezorClientDebugLink(TrezorClient): def _raw_write(self, msg): return super()._raw_write(self._filter_message(msg)) - def _expectation_lines(self, expected, current): + @staticmethod + def _expectation_lines(expected, current): start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) output = [] @@ -579,7 +575,8 @@ class TrezorClientDebugLink(TrezorClient): output.append("") return output - def _verify_responses(self, expected, actual): + @classmethod + def _verify_responses(cls, expected, actual): __tracebackhide__ = True # for pytest # pylint: disable=W0612 if expected is None and actual is None: @@ -587,7 +584,7 @@ class TrezorClientDebugLink(TrezorClient): for i, (exp, act) in enumerate(zip_longest(expected, actual)): if exp is None: - output = self._expectation_lines(expected, i) + output = cls._expectation_lines(expected, i) output.append("No more messages were expected, but we got:") for resp in actual[i:]: output.append( @@ -596,12 +593,12 @@ class TrezorClientDebugLink(TrezorClient): raise AssertionError("\n".join(output)) if act is None: - output = self._expectation_lines(expected, i) + output = cls._expectation_lines(expected, i) output.append("This and the following message was not received.") raise AssertionError("\n".join(output)) if not exp.match(act): - output = self._expectation_lines(expected, i) + output = cls._expectation_lines(expected, i) output.append("Actually received:") output.append(textwrap.indent(protobuf.format_message(act), " ")) raise AssertionError("\n".join(output))