diff --git a/cmd.py b/cmd.py index 920394779..276eb41d1 100755 --- a/cmd.py +++ b/cmd.py @@ -99,9 +99,6 @@ class Commands(object): def get_public_node(self, args): return self.client.get_public_node(args.n) - def get_serial_number(self, args): - return binascii.hexlify(self.client.get_serial_number()) - def set_label(self, args): return self.client.apply_settings(label=args.label) @@ -109,9 +106,14 @@ class Commands(object): return self.client.apply_settings(coin_shortcut=args.coin_shortcut) def load_device(self, args): - seed = ' '.join(args.seed) + if not args.mnemonic and not args.xprv: + raise Exception("Please provide mnemonic or xprv") - return self.client.load_device(seed, args.pin) + if args.mnemonic: + mnemonic = ' '.join(args.mnemonic) + return self.client.load_device_by_mnemonic(mnemonic, args.pin, args.passphrase_protection) + + return self.client.load_device_by_xprv(args.xprv, args.pin, args.passphrase_protection) def sign_message(self, args): return self.client.sign_message(args.n, args.message) @@ -134,7 +136,6 @@ class Commands(object): get_address.help = 'Get bitcoin address in base58 encoding' get_entropy.help = 'Get example entropy' get_features.help = 'Retrieve device features and settings' - get_serial_number.help = 'Get device\'s unique identifier' get_public_node.help = 'Get public node of given path' set_label.help = 'Set new wallet label' set_coin.help = 'Switch device to another crypto currency' @@ -166,8 +167,10 @@ class Commands(object): ) load_device.arguments = ( - (('-s', '--seed'), {'type': str, 'nargs': '+'}), - (('-n', '--pin'), {'type': str, 'default': ''}), + (('-m', '--mnemonic'), {'type': str, 'nargs': '+'}), + (('-x', '--xprv'), {'type': str}), + (('-p', '--pin'), {'type': str, 'default': ''}), + (('-r', '--passphrase-protection'), {'action': 'store_true', 'default': False}), ) sign_message.arguments = ( diff --git a/tests/common.py b/tests/common.py index 67367ab4b..9cd60e43e 100644 --- a/tests/common.py +++ b/tests/common.py @@ -3,7 +3,6 @@ import config from trezorlib.client import TrezorClient from trezorlib.debuglink import DebugLink -from trezorlib import proto class TrezorTest(unittest.TestCase): def setUp(self): @@ -19,7 +18,7 @@ class TrezorTest(unittest.TestCase): self.client.setup_debuglink(button=True, pin_correct=True) self.client.load_device( - seed=self.mnemonic1, + mnemonic=self.mnemonic1, pin=self.pin1) self.client.apply_settings(label='unit testing', coin_shortcut='BTC', language='english') diff --git a/tests/test_basic.py b/tests/test_basic.py index e25a47951..cfb61e9d0 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,7 +1,7 @@ import unittest import common -from trezorlib import proto +from trezorlib import messages_pb2 as messages ''' TODO: @@ -14,16 +14,16 @@ from trezorlib import proto class TestBasic(common.TrezorTest): def test_features(self): - features = self.client.call(proto.Initialize()) + features = self.client.call(messages.Initialize()) # Result is the same as reported by BitkeyClient class self.assertEqual(features, self.client.features) def test_ping(self): - ping = self.client.call(proto.Ping(message='ahoj!')) + ping = self.client.call(messages.Ping(message='ahoj!')) # Ping results in Success(message='Ahoj!') - self.assertEqual(ping, proto.Success(message='ahoj!')) + self.assertEqual(ping, messages.Success(message='ahoj!')) def test_uuid(self): uuid1 = self.client.get_device_id() diff --git a/trezorlib/client.py b/trezorlib/client.py index 22c2375ba..d3440d663 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -2,6 +2,7 @@ import os import time import ckd_public +import tools import messages_pb2 as proto import types_pb2 as types @@ -22,6 +23,8 @@ class CallException(Exception): class PinException(CallException): pass +PRIME_DERIVATION_FLAG = 0x80000000 + class TrezorClient(object): def __init__(self, transport, debuglink=None, @@ -39,6 +42,10 @@ class TrezorClient(object): def _get_local_entropy(self): return os.urandom(32) + + def _convert_prime(self, n): + # Convert minus signs to uint32 with flag + return [ int(abs(x) | PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ] def init_device(self): self.features = self.call(proto.Initialize()) @@ -49,10 +56,11 @@ class TrezorClient(object): self.debuglink.transport.close() def get_public_node(self, n): - # print self.bip32_ckd(self.call(proto.GetPublicKey(address_n=n)).node, [2, ]) + n = self._convert_prime(n) return self.call(proto.GetPublicKey(address_n=n)).node def get_address(self, n): + n = self._convert_prime(n) return self.call(proto.GetAddress(address_n=n)).address def get_entropy(self, size): @@ -139,6 +147,7 @@ class TrezorClient(object): return resp def sign_message(self, n, message): + n = self._convert_prime(n) return self.call(proto.SignMessage(address_n=n, message=message)) def verify_message(self, address, signature, message): @@ -266,8 +275,42 @@ class TrezorClient(object): self.init_device() return isinstance(resp, proto.Success) - def load_device(self, seed, pin): - resp = self.call(proto.LoadDevice(seed=seed, pin=pin)) + def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection): + resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin, passphrase_protection=passphrase_protection)) + self.init_device() + return isinstance(resp, proto.Success) + + def load_device_by_xprv(self, xprv, pin, passphrase_protection): + if xprv[0:4] not in ('xprv', 'tprv'): + raise Exception("Unknown type of xprv") + + if len(xprv) < 100 and len(xprv) > 112: + raise Exception("Invalid length of xprv") + + node = types.HDNodeType() + data = tools.b58decode(xprv, None).encode('hex') + + if data[90:92] != '00': + raise Exception("Contain invalid private key") + + # version 0488ade4 + # depth 00 + # fingerprint 00000000 + # child_num 00000000 + # chaincode 873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508 + # privkey 00e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35 + # wtf is this? e77e9d71 + + node.version = int(data[0:8], 16) + node.depth = int(data[8:10], 16) + node.fingerprint = int(data[10:18], 16) + node.child_num = int(data[18:26], 16) + node.chain_code = data[26:90].decode('hex') + node.private_key = data[92:156].decode('hex') # skip 0x00 indicating privkey + print 'wtf is this?', len(data[156:]) + # FIXME + + resp = self.call(proto.LoadDevice(node=node, pin=pin, passphrase_protection=passphrase_protection)) self.init_device() return isinstance(resp, proto.Success)