1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-03 12:00:59 +00:00

trezorlib: deprecate client.expand_path and move the staticmethod

to an ordinary function tools.parse_path

Also remove PRIME_DERIVATION_FLAG and move it to tools.HARDENED_FLAG
This commit is contained in:
matejcik 2018-04-18 15:00:59 +02:00
parent 4f66b37f25
commit d106869061
4 changed files with 81 additions and 65 deletions

View File

@ -23,18 +23,17 @@
import base64 import base64
import binascii import binascii
import click import click
import functools
import json import json
import os import os
import sys import sys
from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException, format_protobuf from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException, format_protobuf
from trezorlib.transport import get_transport, enumerate_devices, TransportException from trezorlib.transport import get_transport, enumerate_devices
from trezorlib import coins from trezorlib import coins
from trezorlib import messages as proto from trezorlib import messages as proto
from trezorlib import protobuf from trezorlib import protobuf
from trezorlib.ckd_public import PRIME_DERIVATION_FLAG
from trezorlib import stellar from trezorlib import stellar
from trezorlib import tools
class ChoiceType(click.Choice): class ChoiceType(click.Choice):
@ -449,7 +448,7 @@ def self_test(connect):
@click.pass_obj @click.pass_obj
def get_address(connect, coin, address, script_type, show_display): def get_address(connect, coin, address, script_type, show_display):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
return client.get_address(coin, address_n, show_display, script_type=script_type) return client.get_address(coin, address_n, show_display, script_type=script_type)
@ -461,7 +460,7 @@ def get_address(connect, coin, address, script_type, show_display):
@click.pass_obj @click.pass_obj
def get_public_node(connect, coin, address, curve, show_display): def get_public_node(connect, coin, address, curve, show_display):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
result = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin) result = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin)
return { return {
'node': { 'node': {
@ -502,7 +501,7 @@ def sign_tx(connect, coin):
if address_n is None: if address_n is None:
pass pass
elif address_n[0] == (49 | PRIME_DERIVATION_FLAG): elif address_n[0] == tools.H_(49):
script_type = 'p2shsegwit' script_type = 'p2shsegwit'
return script_type return script_type
@ -518,7 +517,7 @@ def sign_tx(connect, coin):
if not prev: if not prev:
break break
prev_hash, prev_index = prev prev_hash, prev_index = prev
address_n = click.prompt('BIP-32 path to derive the key', type=client.expand_path) address_n = click.prompt('BIP-32 path to derive the key', type=tools.parse_path)
amount = click.prompt('Input amount (satoshis)', type=int, default=0) amount = click.prompt('Input amount (satoshis)', type=int, default=0)
sequence = click.prompt('Sequence Number to use (RBF opt-in enabled by default)', type=int, default=0xfffffffd) sequence = click.prompt('Sequence Number to use (RBF opt-in enabled by default)', type=int, default=0xfffffffd)
script_type = click.prompt('Input type', type=CHOICE_INPUT_SCRIPT_TYPE, default=default_script_type(address_n)) script_type = click.prompt('Input type', type=CHOICE_INPUT_SCRIPT_TYPE, default=default_script_type(address_n))
@ -540,7 +539,7 @@ def sign_tx(connect, coin):
address_n = None address_n = None
else: else:
address = None address = None
address_n = click.prompt('BIP-32 path (for change output)', type=client.expand_path, default='') address_n = click.prompt('BIP-32 path (for change output)', type=tools.parse_path, default='')
if not address_n: if not address_n:
break break
amount = click.prompt('Amount to spend (satoshis)', type=int) amount = click.prompt('Amount to spend (satoshis)', type=int)
@ -581,7 +580,7 @@ def sign_tx(connect, coin):
@click.pass_obj @click.pass_obj
def sign_message(connect, coin, address, message, script_type): def sign_message(connect, coin, address, message, script_type):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
typemap = { typemap = {
'address': proto.InputScriptType.SPENDADDRESS, 'address': proto.InputScriptType.SPENDADDRESS,
'segwit': proto.InputScriptType.SPENDWITNESS, 'segwit': proto.InputScriptType.SPENDWITNESS,
@ -613,7 +612,7 @@ def verify_message(connect, coin, address, signature, message):
@click.pass_obj @click.pass_obj
def ethereum_sign_message(connect, address, message): def ethereum_sign_message(connect, address, message):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
ret = client.ethereum_sign_message(address_n, message) ret = client.ethereum_sign_message(address_n, message)
output = { output = {
'message': message, 'message': message,
@ -648,7 +647,7 @@ def ethereum_verify_message(connect, address, signature, message):
@click.pass_obj @click.pass_obj
def encrypt_keyvalue(connect, address, key, value): def encrypt_keyvalue(connect, address, key, value):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
res = client.encrypt_keyvalue(address_n, key, value.encode()) res = client.encrypt_keyvalue(address_n, key, value.encode())
return binascii.hexlify(res) return binascii.hexlify(res)
@ -660,7 +659,7 @@ def encrypt_keyvalue(connect, address, key, value):
@click.pass_obj @click.pass_obj
def decrypt_keyvalue(connect, address, key, value): def decrypt_keyvalue(connect, address, key, value):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
return client.decrypt_keyvalue(address_n, key, binascii.unhexlify(value)) return client.decrypt_keyvalue(address_n, key, binascii.unhexlify(value))
@ -674,7 +673,7 @@ def decrypt_keyvalue(connect, address, key, value):
def encrypt_message(connect, coin, display_only, address, pubkey, message): def encrypt_message(connect, coin, display_only, address, pubkey, message):
client = connect() client = connect()
pubkey = binascii.unhexlify(pubkey) pubkey = binascii.unhexlify(pubkey)
address_n = client.expand_path(address) address_n = tools.parse_path(address)
res = client.encrypt_message(pubkey, message, display_only, coin, address_n) res = client.encrypt_message(pubkey, message, display_only, coin, address_n)
return { return {
'nonce': binascii.hexlify(res.nonce), 'nonce': binascii.hexlify(res.nonce),
@ -690,7 +689,7 @@ def encrypt_message(connect, coin, display_only, address, pubkey, message):
@click.pass_obj @click.pass_obj
def decrypt_message(connect, address, payload): def decrypt_message(connect, address, payload):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
payload = base64.b64decode(payload) payload = base64.b64decode(payload)
nonce, message, msg_hmac = payload[:33], payload[33:-8], payload[-8:] nonce, message, msg_hmac = payload[:33], payload[33:-8], payload[-8:]
return client.decrypt_message(address_n, nonce, message, msg_hmac) return client.decrypt_message(address_n, nonce, message, msg_hmac)
@ -707,7 +706,7 @@ def decrypt_message(connect, address, payload):
@click.pass_obj @click.pass_obj
def ethereum_get_address(connect, address, show_display): def ethereum_get_address(connect, address, show_display):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
address = client.ethereum_get_address(address_n, show_display) address = client.ethereum_get_address(address_n, show_display)
return '0x%s' % binascii.hexlify(address).decode() return '0x%s' % binascii.hexlify(address).decode()
@ -774,7 +773,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
to_address = ethereum_decode_hex(to) to_address = ethereum_decode_hex(to)
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
address = '0x%s' % (binascii.hexlify(client.ethereum_get_address(address_n)).decode()) address = '0x%s' % (binascii.hexlify(client.ethereum_get_address(address_n)).decode())
if gas_price is None or gas_limit is None or nonce is None or publish: if gas_price is None or gas_limit is None or nonce is None or publish:
@ -836,7 +835,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
@click.pass_obj @click.pass_obj
def nem_get_address(connect, address, network, show_display): def nem_get_address(connect, address, network, show_display):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
return client.nem_get_address(address_n, network, show_display) return client.nem_get_address(address_n, network, show_display)
@ -847,7 +846,7 @@ def nem_get_address(connect, address, network, show_display):
@click.pass_obj @click.pass_obj
def nem_sign_tx(connect, address, file, broadcast): def nem_sign_tx(connect, address, file, broadcast):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
transaction = client.nem_sign_tx(address_n, json.load(file)) transaction = client.nem_sign_tx(address_n, json.load(file))
payload = { payload = {
@ -873,7 +872,7 @@ def nem_sign_tx(connect, address, file, broadcast):
@click.pass_obj @click.pass_obj
def lisk_get_address(connect, address, show_display): def lisk_get_address(connect, address, show_display):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
return client.lisk_get_address(address_n, show_display) return client.lisk_get_address(address_n, show_display)
@ -883,7 +882,7 @@ def lisk_get_address(connect, address, show_display):
@click.pass_obj @click.pass_obj
def lisk_get_public_key(connect, address, show_display): def lisk_get_public_key(connect, address, show_display):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
res = client.lisk_get_public_key(address_n, show_display) res = client.lisk_get_public_key(address_n, show_display)
output = { output = {
"public_key": binascii.hexlify(res.public_key).decode() "public_key": binascii.hexlify(res.public_key).decode()
@ -897,7 +896,7 @@ def lisk_get_public_key(connect, address, show_display):
@click.pass_obj @click.pass_obj
def lisk_sign_message(connect, address, message): def lisk_sign_message(connect, address, message):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
res = client.lisk_sign_message(address_n, message) res = client.lisk_sign_message(address_n, message)
output = { output = {
'message': message, 'message': message,
@ -925,7 +924,7 @@ def lisk_verify_message(connect, pubkey, signature, message):
@click.pass_obj @click.pass_obj
def lisk_sign_tx(connect, address, file): def lisk_sign_tx(connect, address, file):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
transaction = client.lisk_sign_tx(address_n, json.load(file)) transaction = client.lisk_sign_tx(address_n, json.load(file))
payload = { payload = {
@ -946,7 +945,7 @@ def lisk_sign_tx(connect, address, file):
@click.pass_obj @click.pass_obj
def cosi_commit(connect, address, data): def cosi_commit(connect, address, data):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
return client.cosi_commit(address_n, binascii.unhexlify(data)) return client.cosi_commit(address_n, binascii.unhexlify(data))
@ -958,7 +957,7 @@ def cosi_commit(connect, address, data):
@click.pass_obj @click.pass_obj
def cosi_sign(connect, address, data, global_commitment, global_pubkey): def cosi_sign(connect, address, data, global_commitment, global_pubkey):
client = connect() client = connect()
address_n = client.expand_path(address) address_n = tools.parse_path(address)
return client.cosi_sign(address_n, binascii.unhexlify(data), binascii.unhexlify(global_commitment), binascii.unhexlify(global_pubkey)) return client.cosi_sign(address_n, binascii.unhexlify(data), binascii.unhexlify(global_commitment), binascii.unhexlify(global_pubkey))

View File

@ -34,7 +34,6 @@ from . import messages as proto
from . import tools from . import tools
from . import mapping from . import mapping
from . import nem from . import nem
from .coins import slip44
from . import stellar from . import stellar
from .debuglink import DebugLink from .debuglink import DebugLink
from .protobuf import MessageType from .protobuf import MessageType
@ -490,7 +489,6 @@ class DebugLinkMixin(object):
class ProtocolMixin(object): class ProtocolMixin(object):
PRIME_DERIVATION_FLAG = 0x80000000
VENDORS = ('bitcointrezor.com', 'trezor.io') VENDORS = ('bitcointrezor.com', 'trezor.io')
def __init__(self, state=None, *args, **kwargs): def __init__(self, state=None, *args, **kwargs):
@ -513,44 +511,15 @@ class ProtocolMixin(object):
def _get_local_entropy(self): def _get_local_entropy(self):
return os.urandom(32) return os.urandom(32)
def _convert_prime(self, n): @staticmethod
def _convert_prime(n: tools.Address) -> tools.Address:
# Convert minus signs to uint32 with flag # Convert minus signs to uint32 with flag
return [int(abs(x) | self.PRIME_DERIVATION_FLAG) if x < 0 else x for x in n] return [tools.H_(int(abs(x))) if x < 0 else x for x in n]
@staticmethod @staticmethod
def expand_path(n): def expand_path(n):
# Convert string of bip32 path to list of uint32 integers with prime flags warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning)
# 0/-1/1' -> [0, 0x80000001, 0x80000001] return tools.parse_path(n)
if not n:
return []
n = n.split('/')
# m/a/b/c => a/b/c
if n[0] == 'm':
n = n[1:]
# coin_name/a/b/c => 44'/SLIP44_constant'/a/b/c
if n[0] in slip44:
n = ["44'", "%d'" % slip44[n[0]]] + n[1:]
path = []
for x in n:
prime = False
if x.endswith("'"):
x = x.replace('\'', '')
prime = True
if x.startswith('-'):
prime = True
x = abs(int(x))
if prime:
x |= ProtocolMixin.PRIME_DERIVATION_FLAG
path.append(x)
return path
@expect(proto.PublicKey) @expect(proto.PublicKey)
def get_public_node(self, n, ecdsa_curve_name=None, show_display=False, coin_name=None): def get_public_node(self, n, ecdsa_curve_name=None, show_display=False, coin_name=None):

View File

@ -30,8 +30,6 @@ from ecdsa.ellipticcurve import Point, INFINITY
from trezorlib import tools from trezorlib import tools
from trezorlib import messages from trezorlib import messages
PRIME_DERIVATION_FLAG = 0x80000000
def point_to_pubkey(point): def point_to_pubkey(point):
order = SECP256k1.order order = SECP256k1.order
@ -61,7 +59,7 @@ def sec_to_public_pair(pubkey):
def is_prime(n): def is_prime(n):
return (bool)(n & PRIME_DERIVATION_FLAG) return bool(n & tools.HARDENED_FLAG)
def fingerprint(pubkey): def fingerprint(pubkey):

View File

@ -18,9 +18,21 @@
# along with this library. If not, see <http://www.gnu.org/licenses/>. # along with this library. If not, see <http://www.gnu.org/licenses/>.
import hashlib import hashlib
import binascii
import struct import struct
import sys from typing import NewType, List
from .coins import slip44
HARDENED_FLAG = 1 << 31
Address = NewType('Address', List[int])
def H_(x: int) -> int:
"""
Shortcut function that "hardens" a number in a BIP44 path.
"""
return x | HARDENED_FLAG
def Hash(data): def Hash(data):
@ -109,3 +121,41 @@ def b58decode(v, length):
return None return None
return result return result
def parse_path(nstr: str) -> Address:
"""
Convert BIP32 path string to list of uint32 integers with hardened flags.
Several conventions are supported to set the hardened flag: -1, 1', 1h
e.g.: "0/1h/1" -> [0, 0x80000001, 1]
:param nstr: path string
:return: list of integers
"""
if not nstr:
return []
n = nstr.split('/')
# m/a/b/c => a/b/c
if n[0] == 'm':
n = n[1:]
# coin_name/a/b/c => 44'/SLIP44_constant'/a/b/c
if n[0] in slip44:
coin_id = slip44[n[0]]
n[0:1] = ['44h', '{}h'.format(coin_id)]
def str_to_harden(x: str) -> int:
if x.startswith('-'):
return H_(abs(int(x)))
elif x.endswith(('h', "'")):
return H_(int(x[:-1]))
else:
return int(x)
try:
return list(str_to_harden(x) for x in n)
except Exception:
raise ValueError('Invalid BIP32 path', nstr)