1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-17 01:52:02 +00:00

add check for expected type of message responses

This commit is contained in:
Pavol Rusnak 2014-01-27 11:25:27 +01:00
parent 1089f58eb5
commit 24861a1b58
2 changed files with 63 additions and 59 deletions

View File

@ -12,27 +12,27 @@ from trezorlib import messages_pb2 as messages
''' '''
class TestBasic(common.TrezorTest): class TestBasic(common.TrezorTest):
def test_features(self): def test_features(self):
features = self.client.call(messages.Initialize()) features = self.client.call(messages.Initialize())
# Result is the same as reported by BitkeyClient class # Result is the same as reported by BitkeyClient class
self.assertEqual(features, self.client.features) self.assertEqual(features, self.client.features)
def test_ping(self): def test_ping(self):
ping = self.client.call(messages.Ping(message='ahoj!')) ping = self.client.call(messages.Ping(message='ahoj!'))
# Ping results in Success(message='Ahoj!') # Ping results in Success(message='Ahoj!')
self.assertEqual(ping, messages.Success(message='ahoj!')) self.assertEqual(ping, messages.Success(message='ahoj!'))
def test_uuid(self): def test_uuid(self):
uuid1 = self.client.get_device_id() uuid1 = self.client.get_device_id()
self.client.init_device() self.client.init_device()
uuid2 = self.client.get_device_id() uuid2 = self.client.get_device_id()
# UUID must be longer than 10 characters # UUID must be at least 12 characters
self.assertEqual(len(uuid1), 12) self.assertTrue(len(uuid1) >= 12)
# Every resulf of UUID must be the same # Every resulf of UUID must be the same
self.assertEqual(uuid1, uuid2) self.assertEqual(uuid1, uuid2)

View File

@ -11,7 +11,7 @@ from api_blockchain import BlockchainApi
def show_message(message): def show_message(message):
print "MESSAGE FROM DEVICE:", message print "MESSAGE FROM DEVICE:", message
def show_input(input_text, message=None): def show_input(input_text, message=None):
if message: if message:
print "QUESTION FROM DEVICE:", message print "QUESTION FROM DEVICE:", message
@ -32,21 +32,21 @@ class PinException(CallException):
PRIME_DERIVATION_FLAG = 0x80000000 PRIME_DERIVATION_FLAG = 0x80000000
class TrezorClient(object): class TrezorClient(object):
def __init__(self, transport, debuglink=None, def __init__(self, transport, debuglink=None,
message_func=show_message, input_func=show_input, message_func=show_message, input_func=show_input,
pin_func=pin_func, passphrase_func=passphrase_func, pin_func=pin_func, passphrase_func=passphrase_func,
blockchain_api=None, debug=False): blockchain_api=None, debug=False):
self.transport = transport self.transport = transport
self.debuglink = debuglink self.debuglink = debuglink
self.message_func = message_func self.message_func = message_func
self.input_func = input_func self.input_func = input_func
self.pin_func = pin_func self.pin_func = pin_func
self.passphrase_func = passphrase_func self.passphrase_func = passphrase_func
self.debug = debug self.debug = debug
if blockchain_api: if blockchain_api:
self.blockchain = blockchain_api self.blockchain = blockchain_api
else: else:
@ -54,14 +54,14 @@ class TrezorClient(object):
self.setup_debuglink() self.setup_debuglink()
self.init_device() self.init_device()
def _get_local_entropy(self): def _get_local_entropy(self):
return os.urandom(32) return os.urandom(32)
def _convert_prime(self, n): def _convert_prime(self, n):
# Convert minus signs to uint32 with flag # Convert minus signs to uint32 with flag
return [ int(abs(x) | PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ] return [ int(abs(x) | PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ]
def expand_path(self, n): def expand_path(self, n):
# Convert string of bip32 path to list of uint32 integers with prime flags # Convert string of bip32 path to list of uint32 integers with prime flags
# 0/-1/1' -> [0, 0x80000001, 0x80000001] # 0/-1/1' -> [0, 0x80000001, 0x80000001]
@ -79,7 +79,7 @@ class TrezorClient(object):
prime = True prime = True
x = abs(int(x)) x = abs(int(x))
if prime: if prime:
x |= PRIME_DERIVATION_FLAG x |= PRIME_DERIVATION_FLAG
@ -88,7 +88,7 @@ class TrezorClient(object):
return path return path
def init_device(self): def init_device(self):
self.features = self.call(proto.Initialize()) self.features = self.call(proto.Initialize(), proto.Features)
def close(self): def close(self):
self.transport.close() self.transport.close()
@ -96,17 +96,17 @@ class TrezorClient(object):
self.debuglink.transport.close() self.debuglink.transport.close()
def get_public_node(self, n): def get_public_node(self, n):
return self.call(proto.GetPublicKey(address_n=n)).node return self.call(proto.GetPublicKey(address_n=n), proto.PublicKey).node
def get_address(self, coin_name, n): def get_address(self, coin_name, n):
n = self._convert_prime(n) n = self._convert_prime(n)
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name)).address return self.call(proto.GetAddress(address_n=n, coin_name=coin_name), proto.Address).address
def get_entropy(self, size): def get_entropy(self, size):
return self.call(proto.GetEntropy(size=size)).entropy return self.call(proto.GetEntropy(size=size), proto.Entropy).entropy
def ping(self, msg): def ping(self, msg):
return self.call(proto.Ping(message=msg)).message return self.call(proto.Ping(message=msg), proto.Success).message
def get_device_id(self): def get_device_id(self):
return self.features.device_id return self.features.device_id
@ -118,7 +118,7 @@ class TrezorClient(object):
if language: if language:
settings.language = language settings.language = language
out = self.call(settings).message out = self.call(settings, proto.Success).message
self.init_device() # Reload Features self.init_device() # Reload Features
return out return out
@ -129,25 +129,25 @@ class TrezorClient(object):
def setup_debuglink(self, button=None, pin_correct=False): def setup_debuglink(self, button=None, pin_correct=False):
self.debug_button = button self.debug_button = button
self.debug_pin = pin_correct self.debug_pin = pin_correct
def call(self, msg): def call(self, msg, expected = None):
if self.debug: if self.debug:
print '----------------------' print '----------------------'
print "Sending", self._pprint(msg) print "Sending", self._pprint(msg)
try: try:
self.transport.session_begin() self.transport.session_begin()
self.transport.write(msg) self.transport.write(msg)
resp = self.transport.read_blocking() resp = self.transport.read_blocking()
if isinstance(resp, proto.ButtonRequest): if isinstance(resp, proto.ButtonRequest):
if self.debuglink and self.debug_button: if self.debuglink and self.debug_button:
print "Pressing button", self.debug_button print "Pressing button", self.debug_button
self.debuglink.press_button(self.debug_button) self.debuglink.press_button(self.debug_button)
return self.call(proto.ButtonAck()) return self.call(proto.ButtonAck())
if isinstance(resp, proto.PinMatrixRequest): if isinstance(resp, proto.PinMatrixRequest):
if self.debuglink: if self.debuglink:
if self.debug_pin == 1: if self.debug_pin == 1:
@ -161,9 +161,9 @@ class TrezorClient(object):
else: else:
pin = self.pin_func("PIN required: ", resp.message) pin = self.pin_func("PIN required: ", resp.message)
msg2 = proto.PinMatrixAck(pin=pin) msg2 = proto.PinMatrixAck(pin=pin)
return self.call(msg2) return self.call(msg2)
if isinstance(resp, proto.PassphraseRequest): if isinstance(resp, proto.PassphraseRequest):
passphrase = self.passphrase_func("Passphrase required: ") passphrase = self.passphrase_func("Passphrase required: ")
msg2 = proto.PassphraseAck(passphrase=passphrase) msg2 = proto.PassphraseAck(passphrase=passphrase)
@ -171,22 +171,25 @@ class TrezorClient(object):
finally: finally:
self.transport.session_end() self.transport.session_end()
if isinstance(resp, proto.Failure): if isinstance(resp, proto.Failure):
self.message_func(resp.message) self.message_func(resp.message)
if resp.code == types.Failure_ActionCancelled: if resp.code == types.Failure_ActionCancelled:
raise CallException("Action cancelled by user") raise CallException("Action cancelled by user")
elif resp.code in (types.Failure_PinInvalid, elif resp.code in (types.Failure_PinInvalid,
types.Failure_PinCancelled, types.Failure_PinExpected): types.Failure_PinCancelled, types.Failure_PinExpected):
raise PinException("PIN is invalid") raise PinException("PIN is invalid")
raise CallException(resp.code, resp.message) raise CallException(resp.code, resp.message)
if self.debug: if self.debug:
print "Received", self._pprint(resp) print "Received", self._pprint(resp)
if expected and not isinstance(resp, expected):
raise CallException("Expected %s message, got %s message" % (expected.DESCRIPTOR.name, resp.DESCRIPTOR.name))
return resp return resp
def sign_message(self, n, message): def sign_message(self, n, message):
@ -236,7 +239,7 @@ class TrezorClient(object):
''' '''
inputs: list of TxInput inputs: list of TxInput
outputs: list of TxOutput outputs: list of TxOutput
proto.TxInput(index=0, proto.TxInput(index=0,
address_n=0, address_n=0,
amount=0, amount=0,
@ -250,62 +253,63 @@ class TrezorClient(object):
amount=100000000, amount=100000000,
script_type=proto.PAYTOADDRESS, script_type=proto.PAYTOADDRESS,
#script_args= #script_args=
) )
''' '''
start = time.time() start = time.time()
try: try:
self.transport.session_begin() self.transport.session_begin()
# Prepare and send initial message # Prepare and send initial message
tx = proto.SignTx() tx = proto.SignTx()
tx.inputs_count = len(inputs) tx.inputs_count = len(inputs)
tx.outputs_count = len(outputs) tx.outputs_count = len(outputs)
res = self.call(tx) res = self.call(tx)
# Prepare structure for signatures # Prepare structure for signatures
signatures = [None]*len(inputs) signatures = [None]*len(inputs)
serialized_tx = '' serialized_tx = ''
counter = 0 counter = 0
while True: while True:
counter += 1 counter += 1
if isinstance(res, proto.Failure): if isinstance(res, proto.Failure):
raise CallException("Signing failed") raise CallException("Signing failed")
if not isinstance(res, proto.TxRequest): if not isinstance(res, proto.TxRequest):
raise CallException("Unexpected message") raise CallException("Unexpected message")
# If there's some part of signed transaction, let's add it # If there's some part of signed transaction, let's add it
if res.serialized_tx: if res.serialized_tx:
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx) print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
serialized_tx += res.serialized_tx serialized_tx += res.serialized_tx
if res.signed_index >= 0 and res.signature: if res.signed_index >= 0 and res.signature:
print "!!! SIGNED INPUT", res.signed_index print "!!! SIGNED INPUT", res.signed_index
signatures[res.signed_index] = res.signature signatures[res.signed_index] = res.signature
if res.request_index < 0: if res.request_index < 0:
# Device didn't ask for more information, finish workflow # Device didn't ask for more information, finish workflow
break break
# Device asked for one more information, let's process it. # Device asked for one more information, let's process it.
if res.request_type == types.TXOUTPUT: if res.request_type == types.TXOUTPUT:
res = self.call(outputs[res.request_index]) res = self.call(outputs[res.request_index])
continue continue
elif res.request_type == types.TXINPUT: elif res.request_type == types.TXINPUT:
print "REQUESTING", res.request_index print "REQUESTING", res.request_index
res = self.call(inputs[res.request_index]) res = self.call(inputs[res.request_index])
continue continue
finally: finally:
self.transport.session_end() self.transport.session_end()
print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \ print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \
(time.time() - start, counter, len(serialized_tx)) (time.time() - start, counter, len(serialized_tx))
return (signatures, serialized_tx) return (signatures, serialized_tx)
def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language): def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language):
@ -327,7 +331,7 @@ class TrezorClient(object):
return isinstance(resp, proto.Success) return isinstance(resp, proto.Success)
def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language): def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language):
resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin, resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin,
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,