mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-23 07:58:09 +00:00
TrezorDebugClient: Removed expected_buttonrequests, added more generic expected_responses
This commit is contained in:
parent
f48cf157c7
commit
86a2a9f845
@ -15,8 +15,7 @@ def get_buttonrequest_value(code):
|
|||||||
return [ k for k, v in types.ButtonRequestType.items() if v == code][0]
|
return [ k for k, v in types.ButtonRequestType.items() if v == code][0]
|
||||||
|
|
||||||
def pprint(msg):
|
def pprint(msg):
|
||||||
ser = msg.SerializeToString()
|
return "<%s> (%d bytes):\n%s" % (msg.__class__.__name__, msg.ByteSize(), msg)
|
||||||
return "<%s> (%d bytes):\n%s" % (msg.__class__.__name__, len(ser), msg)
|
|
||||||
|
|
||||||
class CallException(Exception):
|
class CallException(Exception):
|
||||||
def __init__(self, code, message):
|
def __init__(self, code, message):
|
||||||
@ -62,13 +61,23 @@ class BaseClient(object):
|
|||||||
self.transport = transport
|
self.transport = transport
|
||||||
super(BaseClient, self).__init__() # *args, **kwargs)
|
super(BaseClient, self).__init__() # *args, **kwargs)
|
||||||
|
|
||||||
def call(self, msg):
|
def call_raw(self, msg):
|
||||||
try:
|
try:
|
||||||
self.transport.session_begin()
|
self.transport.session_begin()
|
||||||
|
|
||||||
self.transport.write(msg)
|
self.transport.write(msg)
|
||||||
resp = self.transport.read_blocking()
|
resp = self.transport.read_blocking()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.transport.session_end()
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def call(self, msg):
|
||||||
|
try:
|
||||||
|
self.transport.session_begin()
|
||||||
|
|
||||||
|
resp = self.call_raw(msg)
|
||||||
handler_name = "callback_%s" % resp.__class__.__name__
|
handler_name = "callback_%s" % resp.__class__.__name__
|
||||||
handler = getattr(self, handler_name, None)
|
handler = getattr(self, handler_name, None)
|
||||||
|
|
||||||
@ -138,8 +147,11 @@ class DebugLinkMixin(object):
|
|||||||
# Always press Yes and provide correct pin
|
# Always press Yes and provide correct pin
|
||||||
self.setup_debuglink(True, True)
|
self.setup_debuglink(True, True)
|
||||||
|
|
||||||
# Do not expect any specific ButtonRequest
|
# Do not expect any specific response from device
|
||||||
self.set_expected_buttonrequests(None)
|
self.set_expected_responses(None)
|
||||||
|
|
||||||
|
# Use blank passphrase
|
||||||
|
self.set_passphrase('')
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
super(DebugLinkMixin, self).close()
|
super(DebugLinkMixin, self).close()
|
||||||
@ -149,33 +161,50 @@ class DebugLinkMixin(object):
|
|||||||
def set_debuglink(self, debug_transport):
|
def set_debuglink(self, debug_transport):
|
||||||
self.debug = DebugLink(debug_transport)
|
self.debug = DebugLink(debug_transport)
|
||||||
|
|
||||||
def set_expected_buttonrequests(self, expected):
|
def set_expected_responses(self, expected):
|
||||||
self.expected_buttonrequests = expected
|
self.expected_responses = expected
|
||||||
|
|
||||||
def setup_debuglink(self, button, pin_correct):
|
def setup_debuglink(self, button, pin_correct):
|
||||||
self.button = button # True -> YES button, False -> NO button
|
self.button = button # True -> YES button, False -> NO button
|
||||||
self.pin_correct = pin_correct
|
self.pin_correct = pin_correct
|
||||||
|
|
||||||
|
def set_passphrase(self, passphrase):
|
||||||
|
self.passphrase = passphrase
|
||||||
|
|
||||||
|
def call_raw(self, msg):
|
||||||
|
resp = super(DebugLinkMixin, self).call_raw(msg)
|
||||||
|
self._check_request(resp)
|
||||||
|
return resp
|
||||||
|
|
||||||
def call(self, msg):
|
def call(self, msg):
|
||||||
print "SENDING", pprint(msg)
|
print "SENDING", pprint(msg)
|
||||||
ret = super(DebugLinkMixin, self).call(msg)
|
ret = super(DebugLinkMixin, self).call(msg)
|
||||||
print "RECEIVED", pprint(ret)
|
print "RECEIVED", pprint(ret)
|
||||||
|
|
||||||
|
if self.expected_responses != None and len(self.expected_responses):
|
||||||
|
raise Exception("Some of expected responses didn't come from device: %s" % \
|
||||||
|
[ pprint(x) for x in self.expected_responses ])
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def callback_ButtonRequest(self, msg):
|
def _check_request(self, msg):
|
||||||
if self.expected_buttonrequests != None:
|
if self.expected_responses != None:
|
||||||
try:
|
try:
|
||||||
expected = self.expected_buttonrequests.pop(0)
|
expected = self.expected_responses.pop(0)
|
||||||
if msg.code != expected:
|
|
||||||
raise CallException(types.Failure_Other,
|
|
||||||
"Expected %s, got %s" % \
|
|
||||||
(get_buttonrequest_value(expected),
|
|
||||||
get_buttonrequest_value(msg.code)))
|
|
||||||
except IndexError:
|
except IndexError:
|
||||||
raise CallException(types.Failure_Other,
|
raise CallException(types.Failure_Other,
|
||||||
"Got %s, but no ButtonRequest has been expected" % \
|
"Got %s, but no message has been expected" % pprint(msg))
|
||||||
get_buttonrequest_value(msg.code))
|
|
||||||
|
|
||||||
|
if msg.__class__ != expected.__class__:
|
||||||
|
raise CallException(types.Failure_Other,
|
||||||
|
"Expected %s, got %s" % (pprint(expected), pprint(msg)))
|
||||||
|
|
||||||
|
fields = expected.ListFields() # only filled (including extensions)
|
||||||
|
for field, value in fields:
|
||||||
|
if not msg.HasField(field.name) or getattr(msg, field.name) != value:
|
||||||
|
raise CallException(types.Failure_Other,
|
||||||
|
"Expected %s, got %s" % (pprint(expected), pprint(msg)))
|
||||||
|
|
||||||
|
def callback_ButtonRequest(self, msg):
|
||||||
print "ButtonRequest code:", get_buttonrequest_value(msg.code)
|
print "ButtonRequest code:", get_buttonrequest_value(msg.code)
|
||||||
|
|
||||||
print "Pressing button", self.button
|
print "Pressing button", self.button
|
||||||
@ -190,7 +219,8 @@ class DebugLinkMixin(object):
|
|||||||
return proto.PinMatrixAck(pin=pin)
|
return proto.PinMatrixAck(pin=pin)
|
||||||
|
|
||||||
def callback_PassphraseRequest(self, msg):
|
def callback_PassphraseRequest(self, msg):
|
||||||
raise Exception("Not implemented yet")
|
print "Provided passphrase: '%s'" % self.passphrase
|
||||||
|
return proto.PassphraseAck(passphrase=self.passphrase)
|
||||||
|
|
||||||
def callback_WordRequest(self, msg):
|
def callback_WordRequest(self, msg):
|
||||||
raise Exception("Not implemented yet")
|
raise Exception("Not implemented yet")
|
||||||
|
Loading…
Reference in New Issue
Block a user