From 558f61e6359bba3da5c670f7f0eb7a186822de18 Mon Sep 17 00:00:00 2001 From: slush0 Date: Thu, 13 Feb 2014 16:46:21 +0100 Subject: [PATCH] Heavily refactored TrezorClient --- trezorlib/client.py | 608 +++++++++++++++++++++++++---------------- trezorlib/debuglink.py | 3 + 2 files changed, 380 insertions(+), 231 deletions(-) diff --git a/trezorlib/client.py b/trezorlib/client.py index 60eb1c08f8..66c53506f5 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -1,32 +1,18 @@ import os -import time import binascii import hashlib import tools import messages_pb2 as proto import types_pb2 as types -from api_blockchain import BlockchainApi +from trezorlib.debuglink import DebugLink # monkeypatching: text formatting of protobuf messages tools.monkeypatch_google_protobuf_text_format() -def show_message(message): - print "MESSAGE FROM DEVICE:", message - -def show_input(input_text, message=None): - if message: - print "QUESTION FROM DEVICE:", message - return raw_input(input_text) - -def pin_func(input_text, message=None): - return show_input(input_text, message) - -def passphrase_func(input_text): - return show_input(input_text) - -def word_func(): - return raw_input("Enter one word of mnemonic: ") +def get_buttonrequest_value(code): + # Converts integer code to its string representation of ButtonRequestType + return [ k for k, v in types.ButtonRequestType.items() if v == code][0] class CallException(Exception): def __init__(self, code, message): @@ -36,39 +22,189 @@ class CallException(Exception): class PinException(CallException): pass -PRIME_DERIVATION_FLAG = 0x80000000 +class field(object): + # Decorator extracts single value from + # protobuf object. If the field is not + # present, raises an exception. + def __init__(self, field): + self.field = field -class TrezorClient(object): + def __call__(self, f): + def wrapped_f(*args, **kwargs): + ret = f(*args, **kwargs) + ret.HasField(self.field) + return getattr(ret, self.field) + return wrapped_f - def __init__(self, transport, debuglink=None, - message_func=show_message, input_func=show_input, - pin_func=pin_func, passphrase_func=passphrase_func, - word_func=word_func, blockchain_api=None, debug=False): +class expect(object): + # Decorator checks if the method + # returned one of expected protobuf messages + # or raises an exception + def __init__(self, *expected): + self.expected = expected + + def __call__(self, f): + def wrapped_f(*args, **kwargs): + ret = f(*args, **kwargs) + if not isinstance(ret, self.expected): + raise Exception("Got %s, expected %s" % (ret.__class__, self.expected)) + return ret + return wrapped_f + +class BaseClient(object): + # Implements very basic layer of sending raw protobuf + # messages to device and getting its response back. + def __init__(self, transport, *args, **kwargs): self.transport = transport - self.debuglink = debuglink + super(BaseClient, self).__init__(*args, **kwargs) - self.message_func = message_func - self.input_func = input_func - self.pin_func = pin_func - self.passphrase_func = passphrase_func - self.word_func = word_func + def call(self, msg): + try: + self.transport.session_begin() - self.debug = debug + self.transport.write(msg) + resp = self.transport.read_blocking() - if blockchain_api: - self.blockchain = blockchain_api + handler_name = "callback_%s" % resp.__class__.__name__ + handler = getattr(self, handler_name, None) + + if handler != None: + msg = handler(resp) + if msg == None: + raise Exception("Callback %s must return protobug message, not None" % handler) + + resp = self.call(msg) + + finally: + self.transport.session_end() + + return resp + + def callback_Failure(self, msg): + if msg.code in (types.Failure_PinInvalid, + types.Failure_PinCancelled, types.Failure_PinExpected): + raise PinException(msg.code, msg.message) + + raise CallException(msg.code, msg.message) + + def close(self): + self.transport.close() + +class TextUIMixin(object): + # This class demonstrates easy test-based UI + # integration between the device and wallet. + # You can implement similar functionality + # by implementing your own GuiMixin with + # graphical widgets for every type of these callbacks. + + def callback_ButtonRequest(self, msg): + print "Sending ButtonAck for %s " % get_buttonrequest_value(msg.code) + return proto.ButtonAck() + + def callback_PinMatrixRequest(self, msg): + pin = raw_input("PIN required: %s " % msg.message) + return proto.PinMatrixAck(pin=pin) + + def callback_PassphraseRequest(self, msg): + passphrase = raw_input("Passphrase required: %s " % msg.message) + return proto.PassphraseAck(passphrase=passphrase) + + def callback_WordRequest(self, msg): + word = raw_input("Enter one word of mnemonic: ") + return proto.WordAck(word=word) + +class DebugLinkMixin(object): + # This class implements automatic responses + # and other functionality for unit tests + # for various callbacks, created in order + # to automatically pass unit tests. + # + # This mixing should be used only for purposes + # of unit testing, because it will fail to work + # without special DebugLink interface provided + # by the device. + + def __init__(self, *args, **kwargs): + super(DebugLinkMixin, self).__init__(*args, **kwargs) + self.debug = None + + # Always press Yes and provide correct pin + self.setup_debuglink(True, True) + + # Do not expect any specific ButtonRequest + self.set_expected_buttonrequests(None) + + def close(self): + super(DebugLinkMixin, self).close() + if self.debug: + self.debug.close() + + def set_debuglink(self, debug_transport): + self.debug = DebugLink(debug_transport) + + def set_expected_buttonrequests(self, expected): + self.expected_buttonrequests = expected + + def setup_debuglink(self, button, pin_correct): + self.button = button # True -> YES button, False -> NO button + self.pin_correct = pin_correct + + def callback_ButtonRequest(self, msg): + if self.expected_buttonrequests != None: + try: + expected = self.expected_buttonrequests.pop(0) + if msg.code != expected: + raise CallException(types.Failure_Other, + "Expected %s, got %s" % \ + (get_buttonrequest_value(expected), + get_buttonrequest_value(msg.code))) + except IndexError: + raise CallException(types.Failure_Other, + "Got %s, but no ButtonRequest has been expected" % \ + get_buttonrequest_value(msg.code)) + + print "ButtonRequest code:", get_buttonrequest_value(msg.code) + + print "Pressing button", self.button + self.debug.press_button(self.button) + return proto.ButtonAck() + + def callback_PinMatrixRequest(self, msg): + if self.pin_correct: + pin = self.debug.read_pin_encoded() else: - self.blockchain = BlockchainApi() + pin = '444222' + return proto.PinMatrixAck(pin=pin) - self.setup_debuglink() + def callback_PassphraseRequest(self, msg): + pass + + def callback_WordRequest(self, msg): + pass + +class ProtocolMixin(object): + PRIME_DERIVATION_FLAG = 0x80000000 + + def __init__(self, *args, **kwargs): + super(ProtocolMixin, self).__init__() # *args, **kwargs) self.init_device() + + def get_tx_func_placeholder(txhash): + raise Exception("Please call set_tx_func() first.") + self.get_tx_func = get_tx_func_placeholder + + def set_tx_func(self, tx_func): + self.get_tx_func = tx_func + + def init_device(self): + self.features = expect(proto.Features)(self.call)(proto.Initialize()) def _get_local_entropy(self): return os.urandom(32) def _convert_prime(self, n): # Convert minus signs to uint32 with flag - return [ int(abs(x) | PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ] + return [ int(abs(x) | self.PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ] def expand_path(self, n): # Convert string of bip32 path to list of uint32 integers with prime flags @@ -89,40 +225,42 @@ class TrezorClient(object): x = abs(int(x)) if prime: - x |= PRIME_DERIVATION_FLAG + x |= self.PRIME_DERIVATION_FLAG path.append(x) return path - def init_device(self): - self.features = self.call(proto.Initialize(), proto.Features) - - def close(self): - self.transport.close() - if self.debuglink: - self.debuglink.transport.close() - + @field('node') + @expect(proto.PublicKey) def get_public_node(self, n): - return self.call(proto.GetPublicKey(address_n=n), proto.PublicKey).node + return self.call(proto.GetPublicKey(address_n=n)) + @field('address') + @expect(proto.Address) def get_address(self, coin_name, n): n = self._convert_prime(n) - return self.call(proto.GetAddress(address_n=n, coin_name=coin_name), proto.Address).address + return self.call(proto.GetAddress(address_n=n, coin_name=coin_name)) + @field('entropy') + @expect(proto.Entropy) def get_entropy(self, size): - return self.call(proto.GetEntropy(size=size), proto.Entropy).entropy + return self.call(proto.GetEntropy(size=size)) + @field('message') + @expect(proto.Success) def ping(self, msg, button_protection=False, pin_protection=False, passphrase_protection=False): msg = proto.Ping(message=msg, button_protection=button_protection, pin_protection=pin_protection, passphrase_protection=passphrase_protection) - return self.call(msg, proto.Success).message + return self.call(msg) def get_device_id(self): return self.features.device_id + @field('message') + @expect(proto.Success) def apply_settings(self, label=None, language=None): settings = proto.ApplySettings() if label != None: @@ -130,27 +268,203 @@ class TrezorClient(object): if language: settings.language = language - out = self.call(settings, proto.Success).message - self.init_device() # Reload Features - + out = self.call(settings) + self.init_device() # Reload Features return out + @field('message') + @expect(proto.Success) def change_pin(self, remove=False): ret = self.call(proto.ChangePin(remove=remove)) self.init_device() # Re-read features return ret + @expect(proto.MessageSignature) + def sign_message(self, n, message): + n = self._convert_prime(n) + return self.call(proto.SignMessage(address_n=n, message=message)) + + def verify_message(self, address, signature, message): + try: + resp = self.call(proto.VerifyMessage(address=address, signature=signature, message=message)) + except CallException as e: + resp = e + if isinstance(resp, proto.Success): + return True + return False + + @field('tx_size') + @expect(proto.TxSize) + def estimate_tx_size(self, coin_name, inputs, outputs): + msg = proto.EstimateTxSize() + msg.coin_name = coin_name + msg.inputs_count = len(inputs) + msg.outputs_count = len(outputs) + return self.call(msg) + + def _prepare_simple_sign_tx(self, coin_name, inputs, outputs): + msg = proto.SimpleSignTx() + msg.coin_name = coin_name + msg.inputs.extend(inputs) + msg.outputs.extend(outputs) + + known_hashes = [] + for inp in inputs: + if inp.prev_hash in known_hashes: + continue + + tx = msg.transactions.add() + tx.CopyFrom(self.get_tx_func(binascii.hexlify(inp.prev_hash))) + known_hashes.append(inp.prev_hash) + + return msg + + @field('serialized_tx') + @expect(proto.TxRequest) + def simple_sign_tx(self, coin_name, inputs, outputs): + # TODO Deserialize tx and check if inputs/outputs fits + msg = self._prepare_simple_sign_tx(coin_name, inputs, outputs) + return self.call(msg) + + def sign_tx(self, coin_name, inputs, outputs): + # Temporary solution, until streaming is implemented in the firmware + return self.simple_sign_tx(coin_name, inputs, outputs) + + @field('message') + @expect(proto.Success) + def wipe_device(self): + ret = self.call(proto.WipeDevice()) + self.init_device() + return ret + + @field('message') + @expect(proto.Success) + def recovery_device(self, word_count, passphrase_protection, pin_protection, label, language): + if self.features.initialized: + raise Exception("Device is initialized already. Call wipe_device() and try again.") + + if word_count not in (12, 18, 24): + raise Exception("Invalid word count. Use 12/18/24") + + res = self.call(proto.RecoveryDevice(word_count=int(word_count), + passphrase_protection=bool(passphrase_protection), + pin_protection=bool(pin_protection), + label=label, + language=language, + enforce_wordlist=True)) + + self.init_device() + return res + + @field('message') + @expect(proto.Success) + def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language): + if self.features.initialized: + raise Exception("Device is initialized already. Call wipe_device() and try again.") + + # Begin with device reset workflow + msg = proto.ResetDevice(display_random=display_random, + strength=strength, + language=language, + passphrase_protection=bool(passphrase_protection), + pin_protection=bool(pin_protection), + label=label) + + resp = self.call(msg) + if not isinstance(resp, proto.EntropyRequest): + raise Exception("Invalid response, expected EntropyRequest") + + external_entropy = self._get_local_entropy() + print "Computer generated entropy:", binascii.hexlify(external_entropy) + return self.call(proto.EntropyAck(entropy=external_entropy)) + + @field('message') + @expect(proto.Success) + def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language): + if self.features.initialized: + raise Exception("Device is initialized already. Call wipe_device() and try again.") + + resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin, + passphrase_protection=passphrase_protection, + language=language, + label=label)) + self.init_device() + return resp + + @field('message') + @expect(proto.Success) + def load_device_by_xprv(self, xprv, pin, passphrase_protection, label): + if self.features.initialized: + raise Exception("Device is initialized already. Call wipe_device() and try again.") + + if xprv[0:4] not in ('xprv', 'tprv'): + raise Exception("Unknown type of xprv") + + if len(xprv) < 100 and len(xprv) > 112: + raise Exception("Invalid length of xprv") + + node = types.HDNodeType() + data = tools.b58decode(xprv, None).encode('hex') + + if data[90:92] != '00': + raise Exception("Contain invalid private key") + + checksum = hashlib.sha256(hashlib.sha256(binascii.unhexlify(data[:156])).digest()).hexdigest()[:8] + if checksum != data[156:]: + raise Exception("Checksum doesn't match") + + # version 0488ade4 + # depth 00 + # fingerprint 00000000 + # child_num 00000000 + # chaincode 873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508 + # privkey 00e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35 + # checksum e77e9d71 + + node.version = int(data[0:8], 16) + node.depth = int(data[8:10], 16) + node.fingerprint = int(data[10:18], 16) + node.child_num = int(data[18:26], 16) + node.chain_code = data[26:90].decode('hex') + node.private_key = data[92:156].decode('hex') # skip 0x00 indicating privkey + + resp = self.call(proto.LoadDevice(node=node, + pin=pin, + passphrase_protection=passphrase_protection, + language='english', + label=label)) + self.init_device() + return resp + + def firmware_update(self, fp): + if self.features.bootloader_mode == False: + raise Exception("Device must be in bootloader mode") + + resp = self.call(proto.FirmwareErase()) + if isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError: + return False + + resp = self.call(proto.FirmwareUpload(payload=fp.read())) + if isinstance(resp, proto.Success): + return True + + elif isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError: + return False + + raise Exception("Unexpected result " % resp) + +class TrezorClient(BaseClient, ProtocolMixin, TextUIMixin): + pass + +class TrezorDebugClient(BaseClient, ProtocolMixin, DebugLinkMixin): + pass + +''' +class TrezorClient(object): def _pprint(self, msg): ser = msg.SerializeToString() return "<%s> (%d bytes):\n%s" % (msg.__class__.__name__, len(ser), msg) - def setup_debuglink(self, button=None, pin_correct=False): - self.debug_button = button - self.debug_pin = pin_correct - - def _get_buttonrequest_value(self, code): - return [ k for k, v in types.ButtonRequestType.items() if v == code][0] - def call(self, msg, expected=None, expected_buttonrequests=None): # TODO split this into normal and debug mode if self.debug: @@ -201,7 +515,7 @@ class TrezorClient(object): if isinstance(resp, proto.PassphraseRequest): passphrase = self.passphrase_func("Passphrase required: ") - msg2 = proto.PassphraseAck(passphrase=passphrase) + ms(object)g2 = proto.PassphraseAck(passphrase=passphrase) return self.call(msg2, expected=expected, expected_buttonrequests=expected_buttonrequests) finally: @@ -229,55 +543,8 @@ class TrezorClient(object): return resp - def sign_message(self, n, message): - n = self._convert_prime(n) - return self.call(proto.SignMessage(address_n=n, message=message)) - - def verify_message(self, address, signature, message): - try: - resp = self.call(proto.VerifyMessage(address=address, signature=signature, message=message)) - if isinstance(resp, proto.Success): - return True - except CallException: - pass - - return False - - def estimate_tx_size(self, coin_name, inputs, outputs): - msg = proto.EstimateTxSize() - msg.coin_name = coin_name - msg.inputs_count = len(inputs) - msg.outputs_count = len(outputs) - res = self.call(msg) - return res.tx_size - - def _prepare_simple_sign_tx(self, coin_name, inputs, outputs): - msg = proto.SimpleSignTx() - msg.coin_name = coin_name - msg.inputs.extend(inputs) - msg.outputs.extend(outputs) - - known_hashes = [] - for inp in inputs: - if inp.prev_hash in known_hashes: - continue - - tx = msg.transactions.add() - tx.CopyFrom(self.blockchain.get_tx(binascii.hexlify(inp.prev_hash))) - known_hashes.append(inp.prev_hash) - - return msg - - def simple_sign_tx(self, coin_name, inputs, outputs): - msg = self._prepare_simple_sign_tx(coin_name, inputs, outputs) - return self.call(msg) - - def sign_tx(self, coin_name, inputs, outputs): - # Temporary solution, until streaming is implemented in the firmware - return self.simple_sign_tx(coin_name, inputs, outputs) - def _sign_tx(self, coin_name, inputs, outputs): - ''' + '' inputs: list of TxInput outputs: list of TxOutput @@ -295,7 +562,7 @@ class TrezorClient(object): script_type=proto.PAYTOADDRESS, #script_args= ) - ''' + '' start = time.time() @@ -352,125 +619,4 @@ class TrezorClient(object): (time.time() - start, counter, len(serialized_tx)) return (signatures, serialized_tx) - - def wipe_device(self): - ret = self.call(proto.WipeDevice()) - self.init_device() - return ret - - def recovery_device(self, word_count, passphrase_protection, pin_protection, label, language): - if word_count not in (12, 18, 24): - raise Exception("Invalid word count. Use 12/18/24") - - res = self.call(proto.RecoveryDevice(word_count=int(word_count), - passphrase_protection=bool(passphrase_protection), - pin_protection=bool(pin_protection), - label=label, - language=language, - enforce_wordlist=True)) - - while isinstance(res, proto.WordRequest): - word = self.word_func() - res = self.call(proto.WordAck(word=word)) - - if not isinstance(res, proto.Success): - raise Exception("Recovery device failed") - - self.init_device() - return True - - def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language): - if self.features.initialized: - raise Exception("Device is initialized already. Call wipe_device() and try again.") - - # Begin with device reset workflow - msg = proto.ResetDevice(display_random=display_random, - strength=strength, - language=language, - passphrase_protection=bool(passphrase_protection), - pin_protection=bool(pin_protection), - label=label) - - resp = self.call(msg) - if not isinstance(resp, proto.EntropyRequest): - raise Exception("Invalid response, expected EntropyRequest") - - external_entropy = self._get_local_entropy() - print "Computer generated entropy:", binascii.hexlify(external_entropy) - resp = self.call(proto.EntropyAck(entropy=external_entropy)) - - - return isinstance(resp, proto.Success) - - def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language): - if self.features.initialized: - raise Exception("Device is initialized already. Call wipe_device() and try again.") - - resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin, - passphrase_protection=passphrase_protection, - language=language, - label=label)) - self.init_device() - return isinstance(resp, proto.Success) - - def load_device_by_xprv(self, xprv, pin, passphrase_protection, label): - if self.features.initialized: - raise Exception("Device is initialized already. Call wipe_device() and try again.") - - if xprv[0:4] not in ('xprv', 'tprv'): - raise Exception("Unknown type of xprv") - - if len(xprv) < 100 and len(xprv) > 112: - raise Exception("Invalid length of xprv") - - node = types.HDNodeType() - data = tools.b58decode(xprv, None).encode('hex') - - if data[90:92] != '00': - raise Exception("Contain invalid private key") - - checksum = hashlib.sha256(hashlib.sha256(binascii.unhexlify(data[:156])).digest()).hexdigest()[:8] - if checksum != data[156:]: - raise Exception("Checksum doesn't match") - - # version 0488ade4 - # depth 00 - # fingerprint 00000000 - # child_num 00000000 - # chaincode 873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508 - # privkey 00e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35 - # checksum e77e9d71 - - node.version = int(data[0:8], 16) - node.depth = int(data[8:10], 16) - node.fingerprint = int(data[10:18], 16) - node.child_num = int(data[18:26], 16) - node.chain_code = data[26:90].decode('hex') - node.private_key = data[92:156].decode('hex') # skip 0x00 indicating privkey - - resp = self.call(proto.LoadDevice(node=node, - pin=pin, - passphrase_protection=passphrase_protection, - language='english', - label=label)) - self.init_device() - return isinstance(resp, proto.Success) - - def firmware_update(self, fp): - if self.features.bootloader_mode == False: - raise Exception("Device must be in bootloader mode") - - resp = self.call(proto.FirmwareErase()) - if isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError: - return False - - resp = self.call(proto.FirmwareUpload(payload=fp.read())) - if isinstance(resp, proto.Success): - return True - - elif isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError: - return False - - raise Exception("Unexpected result " % resp) - -# class TrezorDebugClient(TrezorClient): +''' diff --git a/trezorlib/debuglink.py b/trezorlib/debuglink.py index 7bc1d35aaa..9021cc81e1 100644 --- a/trezorlib/debuglink.py +++ b/trezorlib/debuglink.py @@ -14,6 +14,9 @@ class DebugLink(object): self.pin_func = pin_func self.button_func = button_func + def close(self): + self.transport.close() + def read_pin(self): self.transport.write(proto.DebugLinkGetState()) obj = self.transport.read_blocking()