diff --git a/trezorlib/client.py b/trezorlib/client.py index 835a4dc679..846d8cc695 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -15,8 +15,7 @@ def get_buttonrequest_value(code): return [ k for k, v in types.ButtonRequestType.items() if v == code][0] def pprint(msg): - ser = msg.SerializeToString() - return "<%s> (%d bytes):\n%s" % (msg.__class__.__name__, len(ser), msg) + return "<%s> (%d bytes):\n%s" % (msg.__class__.__name__, msg.ByteSize(), msg) class CallException(Exception): def __init__(self, code, message): @@ -62,13 +61,23 @@ class BaseClient(object): self.transport = transport super(BaseClient, self).__init__() # *args, **kwargs) - def call(self, msg): + def call_raw(self, msg): try: self.transport.session_begin() self.transport.write(msg) resp = self.transport.read_blocking() + finally: + self.transport.session_end() + + return resp + + def call(self, msg): + try: + self.transport.session_begin() + + resp = self.call_raw(msg) handler_name = "callback_%s" % resp.__class__.__name__ handler = getattr(self, handler_name, None) @@ -138,8 +147,11 @@ class DebugLinkMixin(object): # Always press Yes and provide correct pin self.setup_debuglink(True, True) - # Do not expect any specific ButtonRequest - self.set_expected_buttonrequests(None) + # Do not expect any specific response from device + self.set_expected_responses(None) + + # Use blank passphrase + self.set_passphrase('') def close(self): super(DebugLinkMixin, self).close() @@ -149,33 +161,50 @@ class DebugLinkMixin(object): def set_debuglink(self, debug_transport): self.debug = DebugLink(debug_transport) - def set_expected_buttonrequests(self, expected): - self.expected_buttonrequests = expected + def set_expected_responses(self, expected): + self.expected_responses = expected def setup_debuglink(self, button, pin_correct): self.button = button # True -> YES button, False -> NO button self.pin_correct = pin_correct + def set_passphrase(self, passphrase): + self.passphrase = passphrase + + def call_raw(self, msg): + resp = super(DebugLinkMixin, self).call_raw(msg) + self._check_request(resp) + return resp + def call(self, msg): print "SENDING", pprint(msg) ret = super(DebugLinkMixin, self).call(msg) print "RECEIVED", pprint(ret) + + if self.expected_responses != None and len(self.expected_responses): + raise Exception("Some of expected responses didn't come from device: %s" % \ + [ pprint(x) for x in self.expected_responses ]) return ret - def callback_ButtonRequest(self, msg): - if self.expected_buttonrequests != None: + def _check_request(self, msg): + if self.expected_responses != None: try: - expected = self.expected_buttonrequests.pop(0) - if msg.code != expected: - raise CallException(types.Failure_Other, - "Expected %s, got %s" % \ - (get_buttonrequest_value(expected), - get_buttonrequest_value(msg.code))) + expected = self.expected_responses.pop(0) except IndexError: raise CallException(types.Failure_Other, - "Got %s, but no ButtonRequest has been expected" % \ - get_buttonrequest_value(msg.code)) + "Got %s, but no message has been expected" % pprint(msg)) + if msg.__class__ != expected.__class__: + raise CallException(types.Failure_Other, + "Expected %s, got %s" % (pprint(expected), pprint(msg))) + + fields = expected.ListFields() # only filled (including extensions) + for field, value in fields: + if not msg.HasField(field.name) or getattr(msg, field.name) != value: + raise CallException(types.Failure_Other, + "Expected %s, got %s" % (pprint(expected), pprint(msg))) + + def callback_ButtonRequest(self, msg): print "ButtonRequest code:", get_buttonrequest_value(msg.code) print "Pressing button", self.button @@ -190,7 +219,8 @@ class DebugLinkMixin(object): return proto.PinMatrixAck(pin=pin) def callback_PassphraseRequest(self, msg): - raise Exception("Not implemented yet") + print "Provided passphrase: '%s'" % self.passphrase + return proto.PassphraseAck(passphrase=self.passphrase) def callback_WordRequest(self, msg): raise Exception("Not implemented yet")