diff --git a/cmd.py b/cmd.py index 8b2bb8676..565449778 100755 --- a/cmd.py +++ b/cmd.py @@ -85,26 +85,28 @@ class Commands(object): pass def get_address(self, args): - return self.client.get_address(args.n) + address_n = self.client.expand_path(args.n) + return self.client.get_address(args.coin, address_n) def get_entropy(self, args): return binascii.hexlify(self.client.get_entropy(args.size)) def get_features(self, args): - return pb2json(self.client.features) + return self.client.features + + def list_coins(self, args): + return [ coin.coin_name for coin in self.client.features.coins ] def ping(self, args): return self.client.ping(args.msg) def get_public_node(self, args): - return self.client.get_public_node(args.n) + address_n = self.client.expand_path(args.n) + return self.client.get_public_node(address_n) def set_label(self, args): return self.client.apply_settings(label=args.label) - def set_coin(self, args): - return self.client.apply_settings(coin_shortcut=args.coin_shortcut) - def load_device(self, args): if not args.mnemonic and not args.xprv: raise Exception("Please provide mnemonic or xprv") @@ -120,7 +122,7 @@ class Commands(object): return self.client.reset_device(True, args.strength, args.passphrase, args.pin, args.label) def sign_message(self, args): - return self.client.sign_message(args.n, args.message) + return pb2json(self.client.sign_message(args.n, args.message), {'message': args.message}) def verify_message(self, args): return self.client.verify_message(args.address, args.signature, args.message) @@ -142,7 +144,7 @@ class Commands(object): get_features.help = 'Retrieve device features and settings' get_public_node.help = 'Get public node of given path' set_label.help = 'Set new wallet label' - set_coin.help = 'Switch device to another crypto currency' + list_coins.help = 'List all supported coin types by the device' load_device.help = 'Load custom configuration to the device' reset_device.help = 'Perform factory reset of the device and generate new seed' sign_message.help = 'Sign message using address of given path' @@ -150,7 +152,9 @@ class Commands(object): firmware_update.help = 'Upload new firmware to device (must be in bootloader mode)' get_address.arguments = ( - (('n',), {'metavar': 'N', 'type': int, 'nargs': '+'}), + (('-c', '--coin'), {'type': str, 'default': 'Bitcoin'}), + # (('n',), {'metavar': 'N', 'type': int, 'nargs': '+'}), + (('-n', '-address'), {'type': str}), ) get_entropy.arguments = ( @@ -159,6 +163,8 @@ class Commands(object): get_features.arguments = () + list_coins.arguments = () + ping.arguments = ( (('msg',), {'type': str}), ) @@ -167,10 +173,6 @@ class Commands(object): (('label',), {'type': str}), ) - set_coin.arguments = ( - (('coin_shortcut',), {'type': str}), - ) - load_device.arguments = ( (('-m', '--mnemonic'), {'type': str, 'nargs': '+'}), (('-x', '--xprv'), {'type': str}), @@ -198,7 +200,7 @@ class Commands(object): ) get_public_node.arguments = ( - (('n',), {'metavar': 'N', 'type': int, 'nargs': '+'}), + (('-n', '-address'), {'type': str}), ) firmware_update.arguments = ( diff --git a/trezorlib/client.py b/trezorlib/client.py index a5edb1a1b..82e0eb726 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -19,6 +19,9 @@ def show_input(input_text, message=None): def pin_func(input_text, message=None): return show_input(input_text, message) +def passphrase_func(input_text): + return show_input(input_text) + class CallException(Exception): pass @@ -30,13 +33,14 @@ PRIME_DERIVATION_FLAG = 0x80000000 class TrezorClient(object): def __init__(self, transport, debuglink=None, - message_func=show_message, input_func=show_input, pin_func=pin_func, debug=False): + message_func=show_message, input_func=show_input, pin_func=pin_func, passphrase_func=passphrase_func, debug=False): self.transport = transport self.debuglink = debuglink self.message_func = message_func self.input_func = input_func self.pin_func = pin_func + self.passphrase_func = passphrase_func self.debug = debug self.setup_debuglink() @@ -49,6 +53,26 @@ class TrezorClient(object): # Convert minus signs to uint32 with flag return [ int(abs(x) | PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ] + def expand_path(self, n): + # Convert string of bip32 path to list of uint32 integers with prime flags + # 0/-1/1' -> [0, 0x80000001, 0x80000001] + n = n.split('/') + path = [] + for x in n: + prime = False + if '\'' in x: + x = x.replace('\'', '') + prime = True + if '-' in x: + prime = True + + if prime: + path.append(abs(int(x)) | PRIME_DERIVATION_FLAG) + else: + path.append(abs(int(x))) + + return path + def init_device(self): self.features = self.call(proto.Initialize()) @@ -58,12 +82,10 @@ class TrezorClient(object): self.debuglink.transport.close() def get_public_node(self, n): - n = self._convert_prime(n) return self.call(proto.GetPublicKey(address_n=n)).node - def get_address(self, n): - n = self._convert_prime(n) - return self.call(proto.GetAddress(address_n=n)).address + def get_address(self, coin_name, n): + return self.call(proto.GetAddress(address_n=n, coin_name=coin_name)).address def get_entropy(self, size): return self.call(proto.GetEntropy(size=size)).entropy @@ -128,6 +150,11 @@ class TrezorClient(object): msg2 = proto.PinMatrixAck(pin=pin) return self.call(msg2) + + if isinstance(resp, proto.PassphraseRequest): + passphrase = self.passphrase_func("Passphrase required: ") + msg2 = proto.PassphraseAck(passphrase=passphrase) + return self.call(msg2) finally: self.transport.session_end()