1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-03 12:00:59 +00:00

trezorlib: deprecate client.expand_path and move the staticmethod

to an ordinary function tools.parse_path

Also remove PRIME_DERIVATION_FLAG and move it to tools.HARDENED_FLAG
This commit is contained in:
matejcik 2018-04-18 15:00:59 +02:00
parent 4f66b37f25
commit d106869061
4 changed files with 81 additions and 65 deletions

View File

@ -23,18 +23,17 @@
import base64
import binascii
import click
import functools
import json
import os
import sys
from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException, format_protobuf
from trezorlib.transport import get_transport, enumerate_devices, TransportException
from trezorlib.transport import get_transport, enumerate_devices
from trezorlib import coins
from trezorlib import messages as proto
from trezorlib import protobuf
from trezorlib.ckd_public import PRIME_DERIVATION_FLAG
from trezorlib import stellar
from trezorlib import tools
class ChoiceType(click.Choice):
@ -449,7 +448,7 @@ def self_test(connect):
@click.pass_obj
def get_address(connect, coin, address, script_type, show_display):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
return client.get_address(coin, address_n, show_display, script_type=script_type)
@ -461,7 +460,7 @@ def get_address(connect, coin, address, script_type, show_display):
@click.pass_obj
def get_public_node(connect, coin, address, curve, show_display):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
result = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin)
return {
'node': {
@ -502,7 +501,7 @@ def sign_tx(connect, coin):
if address_n is None:
pass
elif address_n[0] == (49 | PRIME_DERIVATION_FLAG):
elif address_n[0] == tools.H_(49):
script_type = 'p2shsegwit'
return script_type
@ -518,7 +517,7 @@ def sign_tx(connect, coin):
if not prev:
break
prev_hash, prev_index = prev
address_n = click.prompt('BIP-32 path to derive the key', type=client.expand_path)
address_n = click.prompt('BIP-32 path to derive the key', type=tools.parse_path)
amount = click.prompt('Input amount (satoshis)', type=int, default=0)
sequence = click.prompt('Sequence Number to use (RBF opt-in enabled by default)', type=int, default=0xfffffffd)
script_type = click.prompt('Input type', type=CHOICE_INPUT_SCRIPT_TYPE, default=default_script_type(address_n))
@ -540,7 +539,7 @@ def sign_tx(connect, coin):
address_n = None
else:
address = None
address_n = click.prompt('BIP-32 path (for change output)', type=client.expand_path, default='')
address_n = click.prompt('BIP-32 path (for change output)', type=tools.parse_path, default='')
if not address_n:
break
amount = click.prompt('Amount to spend (satoshis)', type=int)
@ -581,7 +580,7 @@ def sign_tx(connect, coin):
@click.pass_obj
def sign_message(connect, coin, address, message, script_type):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
typemap = {
'address': proto.InputScriptType.SPENDADDRESS,
'segwit': proto.InputScriptType.SPENDWITNESS,
@ -613,7 +612,7 @@ def verify_message(connect, coin, address, signature, message):
@click.pass_obj
def ethereum_sign_message(connect, address, message):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
ret = client.ethereum_sign_message(address_n, message)
output = {
'message': message,
@ -648,7 +647,7 @@ def ethereum_verify_message(connect, address, signature, message):
@click.pass_obj
def encrypt_keyvalue(connect, address, key, value):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
res = client.encrypt_keyvalue(address_n, key, value.encode())
return binascii.hexlify(res)
@ -660,7 +659,7 @@ def encrypt_keyvalue(connect, address, key, value):
@click.pass_obj
def decrypt_keyvalue(connect, address, key, value):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
return client.decrypt_keyvalue(address_n, key, binascii.unhexlify(value))
@ -674,7 +673,7 @@ def decrypt_keyvalue(connect, address, key, value):
def encrypt_message(connect, coin, display_only, address, pubkey, message):
client = connect()
pubkey = binascii.unhexlify(pubkey)
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
res = client.encrypt_message(pubkey, message, display_only, coin, address_n)
return {
'nonce': binascii.hexlify(res.nonce),
@ -690,7 +689,7 @@ def encrypt_message(connect, coin, display_only, address, pubkey, message):
@click.pass_obj
def decrypt_message(connect, address, payload):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
payload = base64.b64decode(payload)
nonce, message, msg_hmac = payload[:33], payload[33:-8], payload[-8:]
return client.decrypt_message(address_n, nonce, message, msg_hmac)
@ -707,7 +706,7 @@ def decrypt_message(connect, address, payload):
@click.pass_obj
def ethereum_get_address(connect, address, show_display):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
address = client.ethereum_get_address(address_n, show_display)
return '0x%s' % binascii.hexlify(address).decode()
@ -774,7 +773,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
to_address = ethereum_decode_hex(to)
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
address = '0x%s' % (binascii.hexlify(client.ethereum_get_address(address_n)).decode())
if gas_price is None or gas_limit is None or nonce is None or publish:
@ -836,7 +835,7 @@ def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_pri
@click.pass_obj
def nem_get_address(connect, address, network, show_display):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
return client.nem_get_address(address_n, network, show_display)
@ -847,7 +846,7 @@ def nem_get_address(connect, address, network, show_display):
@click.pass_obj
def nem_sign_tx(connect, address, file, broadcast):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
transaction = client.nem_sign_tx(address_n, json.load(file))
payload = {
@ -873,7 +872,7 @@ def nem_sign_tx(connect, address, file, broadcast):
@click.pass_obj
def lisk_get_address(connect, address, show_display):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
return client.lisk_get_address(address_n, show_display)
@ -883,7 +882,7 @@ def lisk_get_address(connect, address, show_display):
@click.pass_obj
def lisk_get_public_key(connect, address, show_display):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
res = client.lisk_get_public_key(address_n, show_display)
output = {
"public_key": binascii.hexlify(res.public_key).decode()
@ -897,7 +896,7 @@ def lisk_get_public_key(connect, address, show_display):
@click.pass_obj
def lisk_sign_message(connect, address, message):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
res = client.lisk_sign_message(address_n, message)
output = {
'message': message,
@ -925,7 +924,7 @@ def lisk_verify_message(connect, pubkey, signature, message):
@click.pass_obj
def lisk_sign_tx(connect, address, file):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
transaction = client.lisk_sign_tx(address_n, json.load(file))
payload = {
@ -946,7 +945,7 @@ def lisk_sign_tx(connect, address, file):
@click.pass_obj
def cosi_commit(connect, address, data):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
return client.cosi_commit(address_n, binascii.unhexlify(data))
@ -958,7 +957,7 @@ def cosi_commit(connect, address, data):
@click.pass_obj
def cosi_sign(connect, address, data, global_commitment, global_pubkey):
client = connect()
address_n = client.expand_path(address)
address_n = tools.parse_path(address)
return client.cosi_sign(address_n, binascii.unhexlify(data), binascii.unhexlify(global_commitment), binascii.unhexlify(global_pubkey))

View File

@ -34,7 +34,6 @@ from . import messages as proto
from . import tools
from . import mapping
from . import nem
from .coins import slip44
from . import stellar
from .debuglink import DebugLink
from .protobuf import MessageType
@ -490,7 +489,6 @@ class DebugLinkMixin(object):
class ProtocolMixin(object):
PRIME_DERIVATION_FLAG = 0x80000000
VENDORS = ('bitcointrezor.com', 'trezor.io')
def __init__(self, state=None, *args, **kwargs):
@ -513,44 +511,15 @@ class ProtocolMixin(object):
def _get_local_entropy(self):
return os.urandom(32)
def _convert_prime(self, n):
@staticmethod
def _convert_prime(n: tools.Address) -> tools.Address:
# Convert minus signs to uint32 with flag
return [int(abs(x) | self.PRIME_DERIVATION_FLAG) if x < 0 else x for x in n]
return [tools.H_(int(abs(x))) if x < 0 else x for x in n]
@staticmethod
def expand_path(n):
# Convert string of bip32 path to list of uint32 integers with prime flags
# 0/-1/1' -> [0, 0x80000001, 0x80000001]
if not n:
return []
n = n.split('/')
# m/a/b/c => a/b/c
if n[0] == 'm':
n = n[1:]
# coin_name/a/b/c => 44'/SLIP44_constant'/a/b/c
if n[0] in slip44:
n = ["44'", "%d'" % slip44[n[0]]] + n[1:]
path = []
for x in n:
prime = False
if x.endswith("'"):
x = x.replace('\'', '')
prime = True
if x.startswith('-'):
prime = True
x = abs(int(x))
if prime:
x |= ProtocolMixin.PRIME_DERIVATION_FLAG
path.append(x)
return path
warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning)
return tools.parse_path(n)
@expect(proto.PublicKey)
def get_public_node(self, n, ecdsa_curve_name=None, show_display=False, coin_name=None):

View File

@ -30,8 +30,6 @@ from ecdsa.ellipticcurve import Point, INFINITY
from trezorlib import tools
from trezorlib import messages
PRIME_DERIVATION_FLAG = 0x80000000
def point_to_pubkey(point):
order = SECP256k1.order
@ -61,7 +59,7 @@ def sec_to_public_pair(pubkey):
def is_prime(n):
return (bool)(n & PRIME_DERIVATION_FLAG)
return bool(n & tools.HARDENED_FLAG)
def fingerprint(pubkey):

View File

@ -18,9 +18,21 @@
# along with this library. If not, see <http://www.gnu.org/licenses/>.
import hashlib
import binascii
import struct
import sys
from typing import NewType, List
from .coins import slip44
HARDENED_FLAG = 1 << 31
Address = NewType('Address', List[int])
def H_(x: int) -> int:
"""
Shortcut function that "hardens" a number in a BIP44 path.
"""
return x | HARDENED_FLAG
def Hash(data):
@ -109,3 +121,41 @@ def b58decode(v, length):
return None
return result
def parse_path(nstr: str) -> Address:
"""
Convert BIP32 path string to list of uint32 integers with hardened flags.
Several conventions are supported to set the hardened flag: -1, 1', 1h
e.g.: "0/1h/1" -> [0, 0x80000001, 1]
:param nstr: path string
:return: list of integers
"""
if not nstr:
return []
n = nstr.split('/')
# m/a/b/c => a/b/c
if n[0] == 'm':
n = n[1:]
# coin_name/a/b/c => 44'/SLIP44_constant'/a/b/c
if n[0] in slip44:
coin_id = slip44[n[0]]
n[0:1] = ['44h', '{}h'.format(coin_id)]
def str_to_harden(x: str) -> int:
if x.startswith('-'):
return H_(abs(int(x)))
elif x.endswith(('h', "'")):
return H_(int(x[:-1]))
else:
return int(x)
try:
return list(str_to_harden(x) for x in n)
except Exception:
raise ValueError('Invalid BIP32 path', nstr)