1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-11 07:50:57 +00:00

set_expected_responses enforces using 'with' statement

This commit is contained in:
slush0 2014-02-21 07:28:10 +01:00
parent 9310465946
commit eae7d98b8a

View File

@ -148,7 +148,7 @@ class TextUIMixin(object):
def callback_WordRequest(self, msg): def callback_WordRequest(self, msg):
word = raw_input("Enter one word of mnemonic: ") word = raw_input("Enter one word of mnemonic: ")
return proto.WordAck(word=word) return proto.WordAck(word=word)
class DebugLinkMixin(object): class DebugLinkMixin(object):
# This class implements automatic responses # This class implements automatic responses
# and other functionality for unit tests # and other functionality for unit tests
@ -163,12 +163,13 @@ class DebugLinkMixin(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DebugLinkMixin, self).__init__(*args, **kwargs) super(DebugLinkMixin, self).__init__(*args, **kwargs)
self.debug = None self.debug = None
self.in_with_statement = 0
# Always press Yes and provide correct pin # Always press Yes and provide correct pin
self.setup_debuglink(True, True) self.setup_debuglink(True, True)
# Do not expect any specific response from device # Do not expect any specific response from device
self.set_expected_responses(None) self.expected_responses = None
# Use blank passphrase # Use blank passphrase
self.set_passphrase('') self.set_passphrase('')
@ -181,7 +182,26 @@ class DebugLinkMixin(object):
def set_debuglink(self, debug_transport): def set_debuglink(self, debug_transport):
self.debug = DebugLink(debug_transport) self.debug = DebugLink(debug_transport)
def __enter__(self):
# For usage in with/expected_responses
self.in_with_statement += 1
return self
def __exit__(self, *args):
self.in_with_statement -= 1
# Evaluate missed responses in 'with' statement
if self.expected_responses != None and len(self.expected_responses):
raise Exception("Some of expected responses didn't come from device: %s" % \
[ pprint(x) for x in self.expected_responses ])
# Cleanup
self.expected_responses = None
return False
def set_expected_responses(self, expected): def set_expected_responses(self, expected):
if not self.in_with_statement:
raise Exception("Must be called inside 'with' statement")
self.expected_responses = expected self.expected_responses = expected
def setup_debuglink(self, button, pin_correct): def setup_debuglink(self, button, pin_correct):
@ -195,14 +215,6 @@ class DebugLinkMixin(object):
resp = super(DebugLinkMixin, self).call_raw(msg) resp = super(DebugLinkMixin, self).call_raw(msg)
self._check_request(resp) self._check_request(resp)
return resp return resp
def call(self, msg):
ret = super(DebugLinkMixin, self).call(msg)
if self.expected_responses != None and len(self.expected_responses):
raise Exception("Some of expected responses didn't come from device: %s" % \
[ pprint(x) for x in self.expected_responses ])
return ret
def _check_request(self, msg): def _check_request(self, msg):
if self.expected_responses != None: if self.expected_responses != None: