mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-17 01:52:02 +00:00
add check for expected type of message responses
This commit is contained in:
parent
1089f58eb5
commit
24861a1b58
@ -12,27 +12,27 @@ from trezorlib import messages_pb2 as messages
|
|||||||
|
|
||||||
'''
|
'''
|
||||||
|
|
||||||
class TestBasic(common.TrezorTest):
|
class TestBasic(common.TrezorTest):
|
||||||
def test_features(self):
|
def test_features(self):
|
||||||
features = self.client.call(messages.Initialize())
|
features = self.client.call(messages.Initialize())
|
||||||
|
|
||||||
# Result is the same as reported by BitkeyClient class
|
# Result is the same as reported by BitkeyClient class
|
||||||
self.assertEqual(features, self.client.features)
|
self.assertEqual(features, self.client.features)
|
||||||
|
|
||||||
def test_ping(self):
|
def test_ping(self):
|
||||||
ping = self.client.call(messages.Ping(message='ahoj!'))
|
ping = self.client.call(messages.Ping(message='ahoj!'))
|
||||||
|
|
||||||
# Ping results in Success(message='Ahoj!')
|
# Ping results in Success(message='Ahoj!')
|
||||||
self.assertEqual(ping, messages.Success(message='ahoj!'))
|
self.assertEqual(ping, messages.Success(message='ahoj!'))
|
||||||
|
|
||||||
def test_uuid(self):
|
def test_uuid(self):
|
||||||
uuid1 = self.client.get_device_id()
|
uuid1 = self.client.get_device_id()
|
||||||
self.client.init_device()
|
self.client.init_device()
|
||||||
uuid2 = self.client.get_device_id()
|
uuid2 = self.client.get_device_id()
|
||||||
|
|
||||||
# UUID must be longer than 10 characters
|
# UUID must be at least 12 characters
|
||||||
self.assertEqual(len(uuid1), 12)
|
self.assertTrue(len(uuid1) >= 12)
|
||||||
|
|
||||||
# Every resulf of UUID must be the same
|
# Every resulf of UUID must be the same
|
||||||
self.assertEqual(uuid1, uuid2)
|
self.assertEqual(uuid1, uuid2)
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ from api_blockchain import BlockchainApi
|
|||||||
|
|
||||||
def show_message(message):
|
def show_message(message):
|
||||||
print "MESSAGE FROM DEVICE:", message
|
print "MESSAGE FROM DEVICE:", message
|
||||||
|
|
||||||
def show_input(input_text, message=None):
|
def show_input(input_text, message=None):
|
||||||
if message:
|
if message:
|
||||||
print "QUESTION FROM DEVICE:", message
|
print "QUESTION FROM DEVICE:", message
|
||||||
@ -32,21 +32,21 @@ class PinException(CallException):
|
|||||||
PRIME_DERIVATION_FLAG = 0x80000000
|
PRIME_DERIVATION_FLAG = 0x80000000
|
||||||
|
|
||||||
class TrezorClient(object):
|
class TrezorClient(object):
|
||||||
|
|
||||||
def __init__(self, transport, debuglink=None,
|
def __init__(self, transport, debuglink=None,
|
||||||
message_func=show_message, input_func=show_input,
|
message_func=show_message, input_func=show_input,
|
||||||
pin_func=pin_func, passphrase_func=passphrase_func,
|
pin_func=pin_func, passphrase_func=passphrase_func,
|
||||||
blockchain_api=None, debug=False):
|
blockchain_api=None, debug=False):
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.debuglink = debuglink
|
self.debuglink = debuglink
|
||||||
|
|
||||||
self.message_func = message_func
|
self.message_func = message_func
|
||||||
self.input_func = input_func
|
self.input_func = input_func
|
||||||
self.pin_func = pin_func
|
self.pin_func = pin_func
|
||||||
self.passphrase_func = passphrase_func
|
self.passphrase_func = passphrase_func
|
||||||
|
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
|
||||||
if blockchain_api:
|
if blockchain_api:
|
||||||
self.blockchain = blockchain_api
|
self.blockchain = blockchain_api
|
||||||
else:
|
else:
|
||||||
@ -54,14 +54,14 @@ class TrezorClient(object):
|
|||||||
|
|
||||||
self.setup_debuglink()
|
self.setup_debuglink()
|
||||||
self.init_device()
|
self.init_device()
|
||||||
|
|
||||||
def _get_local_entropy(self):
|
def _get_local_entropy(self):
|
||||||
return os.urandom(32)
|
return os.urandom(32)
|
||||||
|
|
||||||
def _convert_prime(self, n):
|
def _convert_prime(self, n):
|
||||||
# Convert minus signs to uint32 with flag
|
# Convert minus signs to uint32 with flag
|
||||||
return [ int(abs(x) | PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ]
|
return [ int(abs(x) | PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ]
|
||||||
|
|
||||||
def expand_path(self, n):
|
def expand_path(self, n):
|
||||||
# Convert string of bip32 path to list of uint32 integers with prime flags
|
# Convert string of bip32 path to list of uint32 integers with prime flags
|
||||||
# 0/-1/1' -> [0, 0x80000001, 0x80000001]
|
# 0/-1/1' -> [0, 0x80000001, 0x80000001]
|
||||||
@ -79,7 +79,7 @@ class TrezorClient(object):
|
|||||||
prime = True
|
prime = True
|
||||||
|
|
||||||
x = abs(int(x))
|
x = abs(int(x))
|
||||||
|
|
||||||
if prime:
|
if prime:
|
||||||
x |= PRIME_DERIVATION_FLAG
|
x |= PRIME_DERIVATION_FLAG
|
||||||
|
|
||||||
@ -88,7 +88,7 @@ class TrezorClient(object):
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
def init_device(self):
|
def init_device(self):
|
||||||
self.features = self.call(proto.Initialize())
|
self.features = self.call(proto.Initialize(), proto.Features)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
@ -96,17 +96,17 @@ class TrezorClient(object):
|
|||||||
self.debuglink.transport.close()
|
self.debuglink.transport.close()
|
||||||
|
|
||||||
def get_public_node(self, n):
|
def get_public_node(self, n):
|
||||||
return self.call(proto.GetPublicKey(address_n=n)).node
|
return self.call(proto.GetPublicKey(address_n=n), proto.PublicKey).node
|
||||||
|
|
||||||
def get_address(self, coin_name, n):
|
def get_address(self, coin_name, n):
|
||||||
n = self._convert_prime(n)
|
n = self._convert_prime(n)
|
||||||
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name)).address
|
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name), proto.Address).address
|
||||||
|
|
||||||
def get_entropy(self, size):
|
def get_entropy(self, size):
|
||||||
return self.call(proto.GetEntropy(size=size)).entropy
|
return self.call(proto.GetEntropy(size=size), proto.Entropy).entropy
|
||||||
|
|
||||||
def ping(self, msg):
|
def ping(self, msg):
|
||||||
return self.call(proto.Ping(message=msg)).message
|
return self.call(proto.Ping(message=msg), proto.Success).message
|
||||||
|
|
||||||
def get_device_id(self):
|
def get_device_id(self):
|
||||||
return self.features.device_id
|
return self.features.device_id
|
||||||
@ -118,7 +118,7 @@ class TrezorClient(object):
|
|||||||
if language:
|
if language:
|
||||||
settings.language = language
|
settings.language = language
|
||||||
|
|
||||||
out = self.call(settings).message
|
out = self.call(settings, proto.Success).message
|
||||||
self.init_device() # Reload Features
|
self.init_device() # Reload Features
|
||||||
|
|
||||||
return out
|
return out
|
||||||
@ -129,25 +129,25 @@ class TrezorClient(object):
|
|||||||
def setup_debuglink(self, button=None, pin_correct=False):
|
def setup_debuglink(self, button=None, pin_correct=False):
|
||||||
self.debug_button = button
|
self.debug_button = button
|
||||||
self.debug_pin = pin_correct
|
self.debug_pin = pin_correct
|
||||||
|
|
||||||
def call(self, msg):
|
def call(self, msg, expected = None):
|
||||||
if self.debug:
|
if self.debug:
|
||||||
print '----------------------'
|
print '----------------------'
|
||||||
print "Sending", self._pprint(msg)
|
print "Sending", self._pprint(msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.transport.session_begin()
|
self.transport.session_begin()
|
||||||
|
|
||||||
self.transport.write(msg)
|
self.transport.write(msg)
|
||||||
resp = self.transport.read_blocking()
|
resp = self.transport.read_blocking()
|
||||||
|
|
||||||
if isinstance(resp, proto.ButtonRequest):
|
if isinstance(resp, proto.ButtonRequest):
|
||||||
if self.debuglink and self.debug_button:
|
if self.debuglink and self.debug_button:
|
||||||
print "Pressing button", self.debug_button
|
print "Pressing button", self.debug_button
|
||||||
self.debuglink.press_button(self.debug_button)
|
self.debuglink.press_button(self.debug_button)
|
||||||
|
|
||||||
return self.call(proto.ButtonAck())
|
return self.call(proto.ButtonAck())
|
||||||
|
|
||||||
if isinstance(resp, proto.PinMatrixRequest):
|
if isinstance(resp, proto.PinMatrixRequest):
|
||||||
if self.debuglink:
|
if self.debuglink:
|
||||||
if self.debug_pin == 1:
|
if self.debug_pin == 1:
|
||||||
@ -161,9 +161,9 @@ class TrezorClient(object):
|
|||||||
else:
|
else:
|
||||||
pin = self.pin_func("PIN required: ", resp.message)
|
pin = self.pin_func("PIN required: ", resp.message)
|
||||||
msg2 = proto.PinMatrixAck(pin=pin)
|
msg2 = proto.PinMatrixAck(pin=pin)
|
||||||
|
|
||||||
return self.call(msg2)
|
return self.call(msg2)
|
||||||
|
|
||||||
if isinstance(resp, proto.PassphraseRequest):
|
if isinstance(resp, proto.PassphraseRequest):
|
||||||
passphrase = self.passphrase_func("Passphrase required: ")
|
passphrase = self.passphrase_func("Passphrase required: ")
|
||||||
msg2 = proto.PassphraseAck(passphrase=passphrase)
|
msg2 = proto.PassphraseAck(passphrase=passphrase)
|
||||||
@ -171,22 +171,25 @@ class TrezorClient(object):
|
|||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.transport.session_end()
|
self.transport.session_end()
|
||||||
|
|
||||||
if isinstance(resp, proto.Failure):
|
if isinstance(resp, proto.Failure):
|
||||||
self.message_func(resp.message)
|
self.message_func(resp.message)
|
||||||
|
|
||||||
if resp.code == types.Failure_ActionCancelled:
|
if resp.code == types.Failure_ActionCancelled:
|
||||||
raise CallException("Action cancelled by user")
|
raise CallException("Action cancelled by user")
|
||||||
|
|
||||||
elif resp.code in (types.Failure_PinInvalid,
|
elif resp.code in (types.Failure_PinInvalid,
|
||||||
types.Failure_PinCancelled, types.Failure_PinExpected):
|
types.Failure_PinCancelled, types.Failure_PinExpected):
|
||||||
raise PinException("PIN is invalid")
|
raise PinException("PIN is invalid")
|
||||||
|
|
||||||
raise CallException(resp.code, resp.message)
|
raise CallException(resp.code, resp.message)
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
print "Received", self._pprint(resp)
|
print "Received", self._pprint(resp)
|
||||||
|
|
||||||
|
if expected and not isinstance(resp, expected):
|
||||||
|
raise CallException("Expected %s message, got %s message" % (expected.DESCRIPTOR.name, resp.DESCRIPTOR.name))
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def sign_message(self, n, message):
|
def sign_message(self, n, message):
|
||||||
@ -236,7 +239,7 @@ class TrezorClient(object):
|
|||||||
'''
|
'''
|
||||||
inputs: list of TxInput
|
inputs: list of TxInput
|
||||||
outputs: list of TxOutput
|
outputs: list of TxOutput
|
||||||
|
|
||||||
proto.TxInput(index=0,
|
proto.TxInput(index=0,
|
||||||
address_n=0,
|
address_n=0,
|
||||||
amount=0,
|
amount=0,
|
||||||
@ -250,62 +253,63 @@ class TrezorClient(object):
|
|||||||
amount=100000000,
|
amount=100000000,
|
||||||
script_type=proto.PAYTOADDRESS,
|
script_type=proto.PAYTOADDRESS,
|
||||||
#script_args=
|
#script_args=
|
||||||
)
|
)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.transport.session_begin()
|
self.transport.session_begin()
|
||||||
|
|
||||||
# Prepare and send initial message
|
# Prepare and send initial message
|
||||||
tx = proto.SignTx()
|
tx = proto.SignTx()
|
||||||
tx.inputs_count = len(inputs)
|
tx.inputs_count = len(inputs)
|
||||||
tx.outputs_count = len(outputs)
|
tx.outputs_count = len(outputs)
|
||||||
res = self.call(tx)
|
res = self.call(tx)
|
||||||
|
|
||||||
# Prepare structure for signatures
|
# Prepare structure for signatures
|
||||||
signatures = [None]*len(inputs)
|
signatures = [None]*len(inputs)
|
||||||
serialized_tx = ''
|
serialized_tx = ''
|
||||||
|
|
||||||
counter = 0
|
counter = 0
|
||||||
while True:
|
while True:
|
||||||
counter += 1
|
counter += 1
|
||||||
|
|
||||||
if isinstance(res, proto.Failure):
|
if isinstance(res, proto.Failure):
|
||||||
raise CallException("Signing failed")
|
raise CallException("Signing failed")
|
||||||
|
|
||||||
if not isinstance(res, proto.TxRequest):
|
if not isinstance(res, proto.TxRequest):
|
||||||
raise CallException("Unexpected message")
|
raise CallException("Unexpected message")
|
||||||
|
|
||||||
# If there's some part of signed transaction, let's add it
|
# If there's some part of signed transaction, let's add it
|
||||||
if res.serialized_tx:
|
if res.serialized_tx:
|
||||||
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
|
print "!!! RECEIVED PART OF SERIALIED TX (%d BYTES)" % len(res.serialized_tx)
|
||||||
serialized_tx += res.serialized_tx
|
serialized_tx += res.serialized_tx
|
||||||
|
|
||||||
if res.signed_index >= 0 and res.signature:
|
if res.signed_index >= 0 and res.signature:
|
||||||
print "!!! SIGNED INPUT", res.signed_index
|
print "!!! SIGNED INPUT", res.signed_index
|
||||||
signatures[res.signed_index] = res.signature
|
signatures[res.signed_index] = res.signature
|
||||||
|
|
||||||
if res.request_index < 0:
|
if res.request_index < 0:
|
||||||
# Device didn't ask for more information, finish workflow
|
# Device didn't ask for more information, finish workflow
|
||||||
break
|
break
|
||||||
|
|
||||||
# Device asked for one more information, let's process it.
|
# Device asked for one more information, let's process it.
|
||||||
if res.request_type == types.TXOUTPUT:
|
if res.request_type == types.TXOUTPUT:
|
||||||
res = self.call(outputs[res.request_index])
|
res = self.call(outputs[res.request_index])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
elif res.request_type == types.TXINPUT:
|
elif res.request_type == types.TXINPUT:
|
||||||
print "REQUESTING", res.request_index
|
print "REQUESTING", res.request_index
|
||||||
res = self.call(inputs[res.request_index])
|
res = self.call(inputs[res.request_index])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.transport.session_end()
|
self.transport.session_end()
|
||||||
|
|
||||||
print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \
|
print "SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" % \
|
||||||
(time.time() - start, counter, len(serialized_tx))
|
(time.time() - start, counter, len(serialized_tx))
|
||||||
|
|
||||||
return (signatures, serialized_tx)
|
return (signatures, serialized_tx)
|
||||||
|
|
||||||
def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language):
|
def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language):
|
||||||
@ -327,7 +331,7 @@ class TrezorClient(object):
|
|||||||
|
|
||||||
|
|
||||||
return isinstance(resp, proto.Success)
|
return isinstance(resp, proto.Success)
|
||||||
|
|
||||||
def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language):
|
def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language):
|
||||||
resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin,
|
resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin,
|
||||||
passphrase_protection=passphrase_protection,
|
passphrase_protection=passphrase_protection,
|
||||||
|
Loading…
Reference in New Issue
Block a user