diff --git a/trezorlib/client.py b/trezorlib/client.py index b08eea2fa..a5dfff445 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -148,7 +148,7 @@ class TextUIMixin(object): def callback_WordRequest(self, msg): word = raw_input("Enter one word of mnemonic: ") return proto.WordAck(word=word) - + class DebugLinkMixin(object): # This class implements automatic responses # and other functionality for unit tests @@ -163,12 +163,13 @@ class DebugLinkMixin(object): def __init__(self, *args, **kwargs): super(DebugLinkMixin, self).__init__(*args, **kwargs) self.debug = None + self.in_with_statement = 0 # Always press Yes and provide correct pin self.setup_debuglink(True, True) # Do not expect any specific response from device - self.set_expected_responses(None) + self.expected_responses = None # Use blank passphrase self.set_passphrase('') @@ -181,7 +182,26 @@ class DebugLinkMixin(object): def set_debuglink(self, debug_transport): self.debug = DebugLink(debug_transport) + def __enter__(self): + # For usage in with/expected_responses + self.in_with_statement += 1 + return self + + def __exit__(self, *args): + self.in_with_statement -= 1 + + # Evaluate missed responses in 'with' statement + if self.expected_responses != None and len(self.expected_responses): + raise Exception("Some of expected responses didn't come from device: %s" % \ + [ pprint(x) for x in self.expected_responses ]) + + # Cleanup + self.expected_responses = None + return False + def set_expected_responses(self, expected): + if not self.in_with_statement: + raise Exception("Must be called inside 'with' statement") self.expected_responses = expected def setup_debuglink(self, button, pin_correct): @@ -195,14 +215,6 @@ class DebugLinkMixin(object): resp = super(DebugLinkMixin, self).call_raw(msg) self._check_request(resp) return resp - - def call(self, msg): - ret = super(DebugLinkMixin, self).call(msg) - - if self.expected_responses != None and len(self.expected_responses): - raise Exception("Some of expected responses didn't come from device: %s" % \ - [ pprint(x) for x in self.expected_responses ]) - return ret def _check_request(self, msg): if self.expected_responses != None: