From d10686906167a65ee7a300f7bcf8f716a908bf7d Mon Sep 17 00:00:00 2001 From: matejcik Date: Wed, 18 Apr 2018 15:00:59 +0200 Subject: [PATCH] trezorlib: deprecate client.expand_path and move the staticmethod to an ordinary function tools.parse_path Also remove PRIME_DERIVATION_FLAG and move it to tools.HARDENED_FLAG --- trezorctl | 47 ++++++++++++----------- trezorlib/client.py | 41 +++----------------- trezorlib/tests/support/ckd_public.py | 4 +- trezorlib/tools.py | 54 ++++++++++++++++++++++++++- 4 files changed, 81 insertions(+), 65 deletions(-) diff --git a/trezorctl b/trezorctl index edb15fb23d..1fc7394cc4 100755 --- a/trezorctl +++ b/trezorctl @@ -23,18 +23,17 @@ import base64 import binascii import click -import functools import json import os import sys from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException, format_protobuf -from trezorlib.transport import get_transport, enumerate_devices, TransportException +from trezorlib.transport import get_transport, enumerate_devices from trezorlib import coins from trezorlib import messages as proto from trezorlib import protobuf -from trezorlib.ckd_public import PRIME_DERIVATION_FLAG from trezorlib import stellar +from trezorlib import tools class ChoiceType(click.Choice): @@ -449,7 +448,7 @@ def self_test(connect): @click.pass_obj def get_address(connect, coin, address, script_type, show_display): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) return client.get_address(coin, address_n, show_display, script_type=script_type) @@ -461,7 +460,7 @@ def get_address(connect, coin, address, script_type, show_display): @click.pass_obj def get_public_node(connect, coin, address, curve, show_display): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) result = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin) return { 'node': { @@ -502,7 +501,7 @@ def sign_tx(connect, coin): if address_n is None: pass - elif address_n[0] == (49 | PRIME_DERIVATION_FLAG): + elif address_n[0] == tools.H_(49): script_type = 'p2shsegwit' return script_type @@ -518,7 +517,7 @@ def sign_tx(connect, coin): if not prev: break prev_hash, prev_index = prev - address_n = click.prompt('BIP-32 path to derive the key', type=client.expand_path) + address_n = click.prompt('BIP-32 path to derive the key', type=tools.parse_path) amount = click.prompt('Input amount (satoshis)', type=int, default=0) sequence = click.prompt('Sequence Number to use (RBF opt-in enabled by default)', type=int, default=0xfffffffd) script_type = click.prompt('Input type', type=CHOICE_INPUT_SCRIPT_TYPE, default=default_script_type(address_n)) @@ -540,7 +539,7 @@ def sign_tx(connect, coin): address_n = None else: address = None - address_n = click.prompt('BIP-32 path (for change output)', type=client.expand_path, default='') + address_n = click.prompt('BIP-32 path (for change output)', type=tools.parse_path, default='') if not address_n: break amount = click.prompt('Amount to spend (satoshis)', type=int) @@ -581,7 +580,7 @@ def sign_tx(connect, coin): @click.pass_obj def sign_message(connect, coin, address, message, script_type): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) typemap = { 'address': proto.InputScriptType.SPENDADDRESS, 'segwit': proto.InputScriptType.SPENDWITNESS, @@ -613,7 +612,7 @@ def verify_message(connect, coin, address, signature, message): @click.pass_obj def ethereum_sign_message(connect, address, message): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) ret = client.ethereum_sign_message(address_n, message) output = { 'message': message, @@ -648,7 +647,7 @@ def ethereum_verify_message(connect, address, signature, message): @click.pass_obj def encrypt_keyvalue(connect, address, key, value): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) res = client.encrypt_keyvalue(address_n, key, value.encode()) return binascii.hexlify(res) @@ -660,7 +659,7 @@ def encrypt_keyvalue(connect, address, key, value): @click.pass_obj def decrypt_keyvalue(connect, address, key, value): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) return client.decrypt_keyvalue(address_n, key, binascii.unhexlify(value)) @@ -674,7 +673,7 @@ def decrypt_keyvalue(connect, address, key, value): def encrypt_message(connect, coin, display_only, address, pubkey, message): client = connect() pubkey = binascii.unhexlify(pubkey) - address_n = client.expand_path(address) + address_n = tools.parse_path(address) res = client.encrypt_message(pubkey, message, display_only, coin, address_n) return { 'nonce': binascii.hexlify(res.nonce), @@ -690,7 +689,7 @@ def encrypt_message(connect, coin, display_only, address, pubkey, message): @click.pass_obj def decrypt_message(connect, address, payload): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) payload = base64.b64decode(payload) nonce, message, msg_hmac = payload[:33], payload[33:-8], payload[-8:] return client.decrypt_message(address_n, nonce, message, msg_hmac) @@ -707,7 +706,7 @@ def decrypt_message(connect, address, payload): @click.pass_obj def ethereum_get_address(connect, address, show_display): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) address = client.ethereum_get_address(address_n, show_display) return '0x%s' % binascii.hexlify(address).decode() @@ -774,7 +773,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri to_address = ethereum_decode_hex(to) client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) address = '0x%s' % (binascii.hexlify(client.ethereum_get_address(address_n)).decode()) if gas_price is None or gas_limit is None or nonce is None or publish: @@ -836,7 +835,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri @click.pass_obj def nem_get_address(connect, address, network, show_display): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) return client.nem_get_address(address_n, network, show_display) @@ -847,7 +846,7 @@ def nem_get_address(connect, address, network, show_display): @click.pass_obj def nem_sign_tx(connect, address, file, broadcast): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) transaction = client.nem_sign_tx(address_n, json.load(file)) payload = { @@ -873,7 +872,7 @@ def nem_sign_tx(connect, address, file, broadcast): @click.pass_obj def lisk_get_address(connect, address, show_display): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) return client.lisk_get_address(address_n, show_display) @@ -883,7 +882,7 @@ def lisk_get_address(connect, address, show_display): @click.pass_obj def lisk_get_public_key(connect, address, show_display): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) res = client.lisk_get_public_key(address_n, show_display) output = { "public_key": binascii.hexlify(res.public_key).decode() @@ -897,7 +896,7 @@ def lisk_get_public_key(connect, address, show_display): @click.pass_obj def lisk_sign_message(connect, address, message): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) res = client.lisk_sign_message(address_n, message) output = { 'message': message, @@ -925,7 +924,7 @@ def lisk_verify_message(connect, pubkey, signature, message): @click.pass_obj def lisk_sign_tx(connect, address, file): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) transaction = client.lisk_sign_tx(address_n, json.load(file)) payload = { @@ -946,7 +945,7 @@ def lisk_sign_tx(connect, address, file): @click.pass_obj def cosi_commit(connect, address, data): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) return client.cosi_commit(address_n, binascii.unhexlify(data)) @@ -958,7 +957,7 @@ def cosi_commit(connect, address, data): @click.pass_obj def cosi_sign(connect, address, data, global_commitment, global_pubkey): client = connect() - address_n = client.expand_path(address) + address_n = tools.parse_path(address) return client.cosi_sign(address_n, binascii.unhexlify(data), binascii.unhexlify(global_commitment), binascii.unhexlify(global_pubkey)) diff --git a/trezorlib/client.py b/trezorlib/client.py index 0479cf10bb..eedc6225c4 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -34,7 +34,6 @@ from . import messages as proto from . import tools from . import mapping from . import nem -from .coins import slip44 from . import stellar from .debuglink import DebugLink from .protobuf import MessageType @@ -490,7 +489,6 @@ class DebugLinkMixin(object): class ProtocolMixin(object): - PRIME_DERIVATION_FLAG = 0x80000000 VENDORS = ('bitcointrezor.com', 'trezor.io') def __init__(self, state=None, *args, **kwargs): @@ -513,44 +511,15 @@ class ProtocolMixin(object): def _get_local_entropy(self): return os.urandom(32) - def _convert_prime(self, n): + @staticmethod + def _convert_prime(n: tools.Address) -> tools.Address: # Convert minus signs to uint32 with flag - return [int(abs(x) | self.PRIME_DERIVATION_FLAG) if x < 0 else x for x in n] + return [tools.H_(int(abs(x))) if x < 0 else x for x in n] @staticmethod def expand_path(n): - # Convert string of bip32 path to list of uint32 integers with prime flags - # 0/-1/1' -> [0, 0x80000001, 0x80000001] - if not n: - return [] - - n = n.split('/') - - # m/a/b/c => a/b/c - if n[0] == 'm': - n = n[1:] - - # coin_name/a/b/c => 44'/SLIP44_constant'/a/b/c - if n[0] in slip44: - n = ["44'", "%d'" % slip44[n[0]]] + n[1:] - - path = [] - for x in n: - prime = False - if x.endswith("'"): - x = x.replace('\'', '') - prime = True - if x.startswith('-'): - prime = True - - x = abs(int(x)) - - if prime: - x |= ProtocolMixin.PRIME_DERIVATION_FLAG - - path.append(x) - - return path + warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning) + return tools.parse_path(n) @expect(proto.PublicKey) def get_public_node(self, n, ecdsa_curve_name=None, show_display=False, coin_name=None): diff --git a/trezorlib/tests/support/ckd_public.py b/trezorlib/tests/support/ckd_public.py index e1ee5d8fa4..e377cf05a9 100644 --- a/trezorlib/tests/support/ckd_public.py +++ b/trezorlib/tests/support/ckd_public.py @@ -30,8 +30,6 @@ from ecdsa.ellipticcurve import Point, INFINITY from trezorlib import tools from trezorlib import messages -PRIME_DERIVATION_FLAG = 0x80000000 - def point_to_pubkey(point): order = SECP256k1.order @@ -61,7 +59,7 @@ def sec_to_public_pair(pubkey): def is_prime(n): - return (bool)(n & PRIME_DERIVATION_FLAG) + return bool(n & tools.HARDENED_FLAG) def fingerprint(pubkey): diff --git a/trezorlib/tools.py b/trezorlib/tools.py index 56e409dceb..8029cb1cbf 100644 --- a/trezorlib/tools.py +++ b/trezorlib/tools.py @@ -18,9 +18,21 @@ # along with this library. If not, see . import hashlib -import binascii import struct -import sys +from typing import NewType, List + +from .coins import slip44 + +HARDENED_FLAG = 1 << 31 + +Address = NewType('Address', List[int]) + + +def H_(x: int) -> int: + """ + Shortcut function that "hardens" a number in a BIP44 path. + """ + return x | HARDENED_FLAG def Hash(data): @@ -109,3 +121,41 @@ def b58decode(v, length): return None return result + + +def parse_path(nstr: str) -> Address: + """ + Convert BIP32 path string to list of uint32 integers with hardened flags. + Several conventions are supported to set the hardened flag: -1, 1', 1h + + e.g.: "0/1h/1" -> [0, 0x80000001, 1] + + :param nstr: path string + :return: list of integers + """ + if not nstr: + return [] + + n = nstr.split('/') + + # m/a/b/c => a/b/c + if n[0] == 'm': + n = n[1:] + + # coin_name/a/b/c => 44'/SLIP44_constant'/a/b/c + if n[0] in slip44: + coin_id = slip44[n[0]] + n[0:1] = ['44h', '{}h'.format(coin_id)] + + def str_to_harden(x: str) -> int: + if x.startswith('-'): + return H_(abs(int(x))) + elif x.endswith(('h', "'")): + return H_(int(x[:-1])) + else: + return int(x) + + try: + return list(str_to_harden(x) for x in n) + except Exception: + raise ValueError('Invalid BIP32 path', nstr)