1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-02 04:42:33 +00:00

add check for expected type of message responses

This commit is contained in:
Pavol Rusnak 2014-01-27 11:25:27 +01:00
parent 1089f58eb5
commit 24861a1b58
2 changed files with 63 additions and 59 deletions

View File

@ -30,8 +30,8 @@ class TestBasic(common.TrezorTest):
self.client.init_device() self.client.init_device()
uuid2 = self.client.get_device_id() uuid2 = self.client.get_device_id()
# UUID must be longer than 10 characters # UUID must be at least 12 characters
self.assertEqual(len(uuid1), 12) self.assertTrue(len(uuid1) >= 12)
# Every resulf of UUID must be the same # Every resulf of UUID must be the same
self.assertEqual(uuid1, uuid2) self.assertEqual(uuid1, uuid2)

View File

@ -88,7 +88,7 @@ class TrezorClient(object):
return path return path
def init_device(self): def init_device(self):
self.features = self.call(proto.Initialize()) self.features = self.call(proto.Initialize(), proto.Features)
def close(self): def close(self):
self.transport.close() self.transport.close()
@ -96,17 +96,17 @@ class TrezorClient(object):
self.debuglink.transport.close() self.debuglink.transport.close()
def get_public_node(self, n): def get_public_node(self, n):
return self.call(proto.GetPublicKey(address_n=n)).node return self.call(proto.GetPublicKey(address_n=n), proto.PublicKey).node
def get_address(self, coin_name, n): def get_address(self, coin_name, n):
n = self._convert_prime(n) n = self._convert_prime(n)
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name)).address return self.call(proto.GetAddress(address_n=n, coin_name=coin_name), proto.Address).address
def get_entropy(self, size): def get_entropy(self, size):
return self.call(proto.GetEntropy(size=size)).entropy return self.call(proto.GetEntropy(size=size), proto.Entropy).entropy
def ping(self, msg): def ping(self, msg):
return self.call(proto.Ping(message=msg)).message return self.call(proto.Ping(message=msg), proto.Success).message
def get_device_id(self): def get_device_id(self):
return self.features.device_id return self.features.device_id
@ -118,7 +118,7 @@ class TrezorClient(object):
if language: if language:
settings.language = language settings.language = language
out = self.call(settings).message out = self.call(settings, proto.Success).message
self.init_device() # Reload Features self.init_device() # Reload Features
return out return out
@ -130,7 +130,7 @@ class TrezorClient(object):
self.debug_button = button self.debug_button = button
self.debug_pin = pin_correct self.debug_pin = pin_correct
def call(self, msg): def call(self, msg, expected = None):
if self.debug: if self.debug:
print '----------------------' print '----------------------'
print "Sending", self._pprint(msg) print "Sending", self._pprint(msg)
@ -187,6 +187,9 @@ class TrezorClient(object):
if self.debug: if self.debug:
print "Received", self._pprint(resp) print "Received", self._pprint(resp)
if expected and not isinstance(resp, expected):
raise CallException("Expected %s message, got %s message" % (expected.DESCRIPTOR.name, resp.DESCRIPTOR.name))
return resp return resp
def sign_message(self, n, message): def sign_message(self, n, message):
@ -252,6 +255,7 @@ class TrezorClient(object):
#script_args= #script_args=
) )
''' '''
start = time.time() start = time.time()
try: try: