add check for expected type of message responses

pull/25/head
Pavol Rusnak 11 years ago
parent 1089f58eb5
commit 24861a1b58

@ -30,8 +30,8 @@ class TestBasic(common.TrezorTest):
self.client.init_device()
uuid2 = self.client.get_device_id()
# UUID must be longer than 10 characters
self.assertEqual(len(uuid1), 12)
# UUID must be at least 12 characters
self.assertTrue(len(uuid1) >= 12)
# Every resulf of UUID must be the same
self.assertEqual(uuid1, uuid2)

@ -88,7 +88,7 @@ class TrezorClient(object):
return path
def init_device(self):
self.features = self.call(proto.Initialize())
self.features = self.call(proto.Initialize(), proto.Features)
def close(self):
self.transport.close()
@ -96,17 +96,17 @@ class TrezorClient(object):
self.debuglink.transport.close()
def get_public_node(self, n):
return self.call(proto.GetPublicKey(address_n=n)).node
return self.call(proto.GetPublicKey(address_n=n), proto.PublicKey).node
def get_address(self, coin_name, n):
n = self._convert_prime(n)
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name)).address
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name), proto.Address).address
def get_entropy(self, size):
return self.call(proto.GetEntropy(size=size)).entropy
return self.call(proto.GetEntropy(size=size), proto.Entropy).entropy
def ping(self, msg):
return self.call(proto.Ping(message=msg)).message
return self.call(proto.Ping(message=msg), proto.Success).message
def get_device_id(self):
return self.features.device_id
@ -118,7 +118,7 @@ class TrezorClient(object):
if language:
settings.language = language
out = self.call(settings).message
out = self.call(settings, proto.Success).message
self.init_device() # Reload Features
return out
@ -130,7 +130,7 @@ class TrezorClient(object):
self.debug_button = button
self.debug_pin = pin_correct
def call(self, msg):
def call(self, msg, expected = None):
if self.debug:
print '----------------------'
print "Sending", self._pprint(msg)
@ -187,6 +187,9 @@ class TrezorClient(object):
if self.debug:
print "Received", self._pprint(resp)
if expected and not isinstance(resp, expected):
raise CallException("Expected %s message, got %s message" % (expected.DESCRIPTOR.name, resp.DESCRIPTOR.name))
return resp
def sign_message(self, n, message):
@ -252,6 +255,7 @@ class TrezorClient(object):
#script_args=
)
'''
start = time.time()
try:

Loading…
Cancel
Save