#!/usr/bin/env python3

# This file is part of the TREZOR project.
#
# Copyright (C) 2012-2017 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2017 Pavol Rusnak <stick@satoshilabs.com>
# Copyright (C) 2016-2017 Jochen Hoenicke <hoenicke@gmail.com>
# Copyright (C) 2017      mruddy
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library.  If not, see <http://www.gnu.org/licenses/>.

import base64
import binascii
import click
import functools
import json
import os
import sys

from trezorlib.client import TrezorClient, TrezorClientVerbose, CallException, format_protobuf
from trezorlib.transport import get_transport, enumerate_devices, TransportException
from trezorlib import messages as proto
from trezorlib import protobuf
from trezorlib.coins import coins_txapi
from trezorlib.ckd_public import PRIME_DERIVATION_FLAG
from trezorlib import stellar


class ChoiceType(click.Choice):
    """click.Choice variant that maps the selected name to a value.

    The accepted choices are the keys of `typemap`; once click.Choice has
    validated the input, the matching `typemap` value is returned instead
    of the name itself.
    """

    def __init__(self, typemap):
        super().__init__(typemap.keys())
        self.typemap = typemap

    def convert(self, value, param, ctx):
        # Validation is delegated to click.Choice; only the lookup is ours.
        chosen = super().convert(value, param, ctx)
        return self.typemap[chosen]


# Recovery methods selectable through `recovery_device --type`.
CHOICE_RECOVERY_DEVICE_TYPE = ChoiceType({
    'scrambled':  proto.RecoveryDeviceType.ScrambledWords,
    'matrix':     proto.RecoveryDeviceType.Matrix,
})

# Input script types, keyed by the name the user types
# (`get_address --script-type` and the interactive `sign_tx` prompt).
CHOICE_INPUT_SCRIPT_TYPE = ChoiceType({
    'address':    proto.InputScriptType.SPENDADDRESS,
    'segwit':     proto.InputScriptType.SPENDWITNESS,
    'p2shsegwit': proto.InputScriptType.SPENDP2SHWITNESS,
})

# Output script types, keyed by the same names as the input table
# (used by the interactive `sign_tx` prompt).
CHOICE_OUTPUT_SCRIPT_TYPE = ChoiceType({
    'address':    proto.OutputScriptType.PAYTOADDRESS,
    'segwit':     proto.OutputScriptType.PAYTOWITNESS,
    'p2shsegwit': proto.OutputScriptType.PAYTOP2SHWITNESS,
})


@click.group(context_settings={'max_content_width': 400})
@click.option('-p', '--path', help='Select device by specific path.', default=os.environ.get('TREZOR_PATH'))
@click.option('-v', '--verbose', is_flag=True, help='Show communication messages.')
@click.option('-j', '--json', 'is_json', is_flag=True, help='Print result as JSON object')
@click.pass_context
def cli(ctx, path, verbose, is_json):
    # NOTE: deliberately no docstring here -- the group is declared without
    # help=, so click would pick a docstring up as the --help text.
    #
    # Every subcommand except `list` receives a lazy connector in ctx.obj;
    # the device is only opened when the command actually calls it.
    if ctx.invoked_subcommand != 'list':
        cls = TrezorClientVerbose if verbose else TrezorClient

        def get_device():
            # Try the exact path first, then fall back to a prefix search.
            # `except Exception` (rather than a bare except) lets
            # KeyboardInterrupt and SystemExit propagate instead of being
            # mistaken for "device not found".
            try:
                device = get_transport(path, prefix_search=False)
            except Exception:
                try:
                    device = get_transport(path, prefix_search=True)
                except Exception:
                    click.echo("Failed to find a TREZOR device.")
                    if path is not None:
                        click.echo("Using path: {}".format(path))
                    sys.exit(1)
            return cls(device)

        ctx.obj = get_device


@cli.resultcallback()
def print_result(res, path, verbose, is_json):
    """Print the value returned by the invoked subcommand.

    With `is_json` set, protobuf messages are dumped as
    ``{ClassName: message.__dict__}`` and anything else is passed to
    json.dumps directly. Otherwise lists are printed one item per line,
    dicts as ``key: value`` (one level of nesting flattened to
    ``key.subkey: value``), protobuf messages via format_protobuf and any
    other value as-is.
    """
    is_message = isinstance(res, protobuf.MessageType)

    if is_json:
        if is_message:
            text = json.dumps({res.__class__.__name__: res.__dict__})
        else:
            text = json.dumps(res, sort_keys=True, indent=4)
        click.echo(text)
        return

    if isinstance(res, list):
        for item in res:
            click.echo(item)
        return

    if isinstance(res, dict):
        for key, val in res.items():
            if not isinstance(val, dict):
                click.echo('%s: %s' % (key, val))
                continue
            for subkey, subval in val.items():
                click.echo('%s.%s: %s' % (key, subkey, subval))
        return

    click.echo(format_protobuf(res) if is_message else res)


#
# Common functions
#


@cli.command(name='list', help='List connected TREZOR devices.')
def ls():
    """Return whatever enumerate_devices() reports; no device is opened."""
    devices = enumerate_devices()
    return devices


@cli.command(help='Show version of trezorctl/trezorlib.')
def version():
    """Return the version string exported by the trezorlib package."""
    # Imported lazily, only when the command is invoked.
    from trezorlib import __version__
    return __version__


#
# Basic device functions
#


@cli.command(help='Send ping message.')
@click.argument('message')
@click.option('-b', '--button-protection', is_flag=True)
@click.option('-p', '--pin-protection', is_flag=True)
@click.option('-r', '--passphrase-protection', is_flag=True)
@click.pass_obj
def ping(connect, message, button_protection, pin_protection, passphrase_protection):
    """Send `message` to the device, forwarding the three protection flags."""
    protection = {
        'button_protection': button_protection,
        'pin_protection': pin_protection,
        'passphrase_protection': passphrase_protection,
    }
    return connect().ping(message, **protection)


@cli.command(help='Clear session (remove cached PIN, passphrase, etc.).')
@click.pass_obj
def clear_session(connect):
    """Open the device and forward to the client's clear_session()."""
    client = connect()
    return client.clear_session()


@cli.command(help='Get example entropy.')
@click.argument('size', type=int)
@click.pass_obj
def get_entropy(connect, size):
    """Fetch `size` bytes of entropy from the device and return them as hex text."""
    entropy = connect().get_entropy(size)
    # hexlify() yields bytes, which json.dumps cannot serialize in --json mode;
    # decode to str, as ethereum_get_address already does.
    return binascii.hexlify(entropy).decode()


@cli.command(help='Retrieve device features and settings.')
@click.pass_obj
def get_features(connect):
    """Return the `features` attribute of the connected client."""
    client = connect()
    return client.features


@cli.command(help='List all supported coin types by the device.')
@click.pass_obj
def list_coins(connect):
    """Return the coin_name of every entry in the device's features.coins."""
    features = connect().features
    names = []
    for coin in features.coins:
        names.append(coin.coin_name)
    return names


#
# Device management functions
#


@cli.command(help='Change new PIN or remove existing.')
@click.option('-r', '--remove', is_flag=True)
@click.pass_obj
def change_pin(connect, remove):
    """Forward the `remove` flag to the client's change_pin()."""
    client = connect()
    return client.change_pin(remove)


@cli.command(help='Enable passphrase.')
@click.pass_obj
def enable_passphrase(connect):
    """Call apply_settings with use_passphrase=True."""
    client = connect()
    return client.apply_settings(use_passphrase=True)


@cli.command(help='Disable passphrase.')
@click.pass_obj
def disable_passphrase(connect):
    """Call apply_settings with use_passphrase=False."""
    client = connect()
    return client.apply_settings(use_passphrase=False)


@cli.command(help='Set new device label.')
@click.option('-l', '--label')
@click.pass_obj
def set_label(connect, label):
    """Pass `label` (None when the option is omitted) to apply_settings."""
    client = connect()
    return client.apply_settings(label=label)


@cli.command(help='Set passphrase source.')
@click.argument('source', type=int)
@click.pass_obj
def set_passphrase_source(connect, source):
    """Pass the integer `source` to apply_settings as passphrase_source."""
    client = connect()
    return client.apply_settings(passphrase_source=source)


@cli.command(help='Set device flags.')
@click.argument('flags')
@click.pass_obj
def set_flags(connect, flags):
    """Parse `flags` as binary (0b...), hex (0x...) or decimal and apply it."""
    text = flags.lower()
    if text.startswith('0b'):
        base = 2
    elif text.startswith('0x'):
        base = 16
    else:
        base = 10
    # Parse before connecting so a malformed value fails without touching the device.
    parsed = int(text, base)
    return connect().apply_flags(flags=parsed)


@cli.command(help='Set new homescreen.')
@click.option('-f', '--filename', default=None)
@click.pass_obj
def set_homescreen(connect, filename):
    """Upload a homescreen image, or a single zero byte when no file is given.

    A ``.toif`` file is sent verbatim after its 8-byte header is checked.
    Any other file is opened with PIL, must be exactly 128x64, and is
    packed into a 1024-byte, 1-bit-per-pixel bitmap.

    Raises:
        CallException: with FailureType.DataError when the TOIF header or
            the image size does not match.
    """
    if filename is None:
        img = b'\x00'
    elif filename.endswith('.toif'):
        # Context manager so the handle is closed even if read() fails.
        with open(filename, 'rb') as f:
            img = f.read()
        if img[:8] != b'TOIf\x90\x00\x90\x00':
            # Was `types.Failure_DataError`: `types` is not defined in this
            # module, so this path raised NameError instead of the intended error.
            raise CallException(proto.FailureType.DataError, 'File is not a TOIF file with size of 144x144')
    else:
        from PIL import Image
        im = Image.open(filename)
        if im.size != (128, 64):
            raise CallException(proto.FailureType.DataError, 'Wrong size of the image')
        im = im.convert('1')
        pix = im.load()
        img = bytearray(1024)
        for j in range(64):
            for i in range(128):
                if pix[i, j]:
                    # Row-major bit offset; within each byte the most
                    # significant bit holds the leftmost pixel.
                    o = (i + j * 128)
                    img[o // 8] |= (1 << (7 - o % 8))
        img = bytes(img)
    return connect().apply_settings(homescreen=img)


@cli.command(help='Set U2F counter.')
@click.argument('counter', type=int)
@click.pass_obj
def set_u2f_counter(connect, counter):
    """Forward the integer `counter` to the client's set_u2f_counter()."""
    client = connect()
    return client.set_u2f_counter(counter)


@cli.command(help='Reset device to factory defaults and remove all private data.')
@click.pass_obj
def wipe_device(connect):
    """Open the device and forward to the client's wipe_device()."""
    client = connect()
    return client.wipe_device()


@cli.command(help='Load custom configuration to the device.')
@click.option('-m', '--mnemonic')
@click.option('-e', '--expand', is_flag=True)
@click.option('-x', '--xprv')
@click.option('-p', '--pin', default='')
@click.option('-r', '--passphrase-protection', is_flag=True)
@click.option('-l', '--label', default='')
@click.option('-i', '--ignore-checksum', is_flag=True)
@click.option('-s', '--slip0014', is_flag=True)
@click.pass_obj
def load_device(connect, mnemonic, expand, xprv, pin, passphrase_protection, label, ignore_checksum, slip0014):
    """Load a seed onto the device from a mnemonic, an xprv or the SLIP-0014 test seed.

    When several sources are given the first of mnemonic, xprv, slip0014
    wins. Raises CallException (DataError) when none is provided.
    """
    if not (mnemonic or xprv or slip0014):
        raise CallException(proto.FailureType.DataError, 'Please provide mnemonic or xprv')

    device = connect()

    if mnemonic:
        return device.load_device_by_mnemonic(
            mnemonic, pin, passphrase_protection, label, 'english', ignore_checksum, expand)

    if xprv:
        return device.load_device_by_xprv(xprv, pin, passphrase_protection, label, 'english')

    # The guard above guarantees slip0014 is set once mnemonic and xprv are
    # both absent. The test seed is the word "all" twelve times, and the
    # label is fixed to 'SLIP-0014' regardless of --label.
    slip0014_mnemonic = ' '.join(['all'] * 12)
    return device.load_device_by_mnemonic(slip0014_mnemonic, pin, passphrase_protection, 'SLIP-0014')


@cli.command(help='Start safe recovery workflow.')
@click.option('-w', '--words', type=click.Choice(['12', '18', '24']), default='24')
@click.option('-e', '--expand', is_flag=True)
@click.option('-p', '--pin-protection', is_flag=True)
@click.option('-r', '--passphrase-protection', is_flag=True)
@click.option('-l', '--label')
@click.option('-t', '--type', 'rec_type', type=CHOICE_RECOVERY_DEVICE_TYPE, default='scrambled')
@click.option('-d', '--dry-run', is_flag=True)
@click.pass_obj
def recovery_device(connect, words, expand, pin_protection, passphrase_protection, label, rec_type, dry_run):
    """Start device recovery with the chosen word count and recovery type."""
    client = connect()
    # `words` arrives as one of the strings '12', '18', '24'.
    word_count = int(words)
    # Arguments are positional; note passphrase_protection precedes pin_protection.
    return client.recovery_device(
        word_count, passphrase_protection, pin_protection, label, 'english', rec_type, expand, dry_run)


@cli.command(help='Perform device setup and generate new seed.')
@click.option('-e', '--entropy', is_flag=True)
@click.option('-t', '--strength', type=click.Choice(['128', '192', '256']), default='256')
@click.option('-r', '--passphrase-protection', is_flag=True)
@click.option('-p', '--pin-protection', is_flag=True)
@click.option('-l', '--label')
@click.option('-u', '--u2f-counter', default=0)
@click.option('-s', '--skip-backup', is_flag=True)
@click.pass_obj
def reset_device(connect, entropy, strength, passphrase_protection, pin_protection, label, u2f_counter, skip_backup):
    """Run the client's reset_device() with the selected options."""
    client = connect()
    # `strength` arrives as one of the strings '128', '192', '256'.
    strength_bits = int(strength)
    return client.reset_device(
        entropy, strength_bits, passphrase_protection, pin_protection, label, 'english', u2f_counter, skip_backup)


@cli.command(help='Perform device seed backup.')
@click.pass_obj
def backup_device(connect):
    """Open the device and forward to the client's backup_device()."""
    client = connect()
    return client.backup_device()


#
# Firmware update
#


@cli.command(help='Upload new firmware to device (must be in bootloader mode).')
@click.option('-f', '--filename')
@click.option('-u', '--url')
@click.option('-v', '--version')
@click.option('-s', '--skip-check', is_flag=True)
@click.option('-e', '--erase', is_flag=True)
@click.pass_obj
def firmware_update(connect, filename, url, version, skip_check, erase):
    """Obtain a firmware image and upload it to the device.

    The image source is the first that applies: a local file (--filename),
    a direct download (--url), a 32768-byte 0xFF filler (--erase), or the
    official release list, from which either --version or the highest
    version is picked.

    Unless --skip-check or --erase is given, the image must start with
    'TRZR' or 'TRZV'; a hex-encoded image is decoded first.

    Raises:
        click.ClickException: the requested --version is not in the release list.
        CallException: the firmware header check failed.
    """
    if filename:
        # Context manager so the handle is closed after reading.
        with open(filename, 'rb') as f:
            fp = f.read()
    elif url:
        import requests
        # click.echo's second positional parameter is the output *file*, so
        # the URL must be formatted into the message instead.
        click.echo('Downloading from %s' % url)
        r = requests.get(url)
        fp = r.content
    elif erase:
        fp = 32768 * b'\xFF'
    else:
        import requests
        r = requests.get('https://wallet.trezor.io/data/firmware/releases.json')
        releases = r.json()

        def version_func(rel):
            return rel['version']

        def version_string(rel):
            return '.'.join(map(str, version_func(rel)))

        if version:
            # A default avoids a bare StopIteration when nothing matches.
            release = next((rel for rel in releases if version_string(rel) == version), None)
            if release is None:
                raise click.ClickException('Firmware version %s not found in the release list' % version)
        else:
            release = max(releases, key=version_func)
            click.echo('Fetching version: %s' % version_string(release))
        click.echo('Firmware fingerprint: %s' % release['fingerprint'])
        url = 'https://wallet.trezor.io/' + release['url']
        click.echo('Downloading from %s' % url)
        r = requests.get(url)
        fp = r.content

    if not skip_check and not erase:
        # b'54525a52' / b'54525a56' are the hex spellings of 'TRZR' / 'TRZV'.
        if fp[:8] == b'54525a52' or fp[:8] == b'54525a56':
            fp = binascii.unhexlify(fp)
        if fp[:4] != b'TRZR' and fp[:4] != b'TRZV':
            raise CallException(proto.FailureType.FirmwareError, 'TREZOR firmware header expected')

    click.echo('Please confirm action on device...')

    from io import BytesIO
    return connect().firmware_update(fp=BytesIO(fp))


@cli.command(help='Perform a self-test.')
@click.pass_obj
def self_test(connect):
    """Open the device and forward to the client's self_test()."""
    client = connect()
    return client.self_test()


#
# Basic coin functions
#


@cli.command(help='Get address for specified path.')
@click.option('-c', '--coin', default='Bitcoin')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0")
@click.option('-t', '--script-type', type=CHOICE_INPUT_SCRIPT_TYPE, default='address')
@click.option('-d', '--show-display', is_flag=True)
@click.pass_obj
def get_address(connect, coin, address, script_type, show_display):
    """Expand the BIP-32 path and request the matching `coin` address."""
    device = connect()
    path_n = device.expand_path(address)
    return device.get_address(coin, path_n, show_display, script_type=script_type)


@cli.command(help='Get public node of given path.')
@click.option('-c', '--coin', default='Bitcoin')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'")
@click.option('-e', '--curve')
@click.option('-d', '--show-display', is_flag=True)
@click.pass_obj
def get_public_node(connect, coin, address, curve, show_display):
    """Return the public node at the given BIP-32 path as a printable dict.

    The result has a 'node' sub-dict (depth, fingerprint, child_num,
    chain_code, public_key) and the 'xpub' string. The fingerprint is
    formatted as 8 hex digits; chain_code and public_key are hex text.
    """
    client = connect()
    address_n = client.expand_path(address)
    result = client.get_public_node(address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin)
    node = result.node
    return {
        'node': {
            'depth': node.depth,
            'fingerprint': "%08x" % node.fingerprint,
            'child_num': node.child_num,
            # hexlify() returns bytes; left undecoded they print as b'...'
            # in text mode and cannot be serialized in --json mode.
            'chain_code': binascii.hexlify(node.chain_code).decode(),
            'public_key': binascii.hexlify(node.public_key).decode(),
        },
        'xpub': result.xpub
    }


#
# Signing options
#

@cli.command(help='Sign transaction.')
@click.option('-c', '--coin', default='Bitcoin')
# @click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0")
# @click.option('-t', '--script-type', type=CHOICE_INPUT_SCRIPT_TYPE, default='address')
# @click.option('-o', '--output', required=True, help='Transaction output')
# @click.option('-f', '--fee', required=True, help='Transaction fee (sat/B)')
@click.pass_obj
def sign_tx(connect, coin):
    """Interactively build a transaction, sign it on the device and print it.

    Inputs and then outputs are prompted for in loops; an empty answer to
    the first prompt of a loop ends that loop. The signed transaction is
    echoed as hex together with the coin's push-tx URL. Exits with status 1
    when `coin` has no entry in coins_txapi.
    """
    client = connect()
    if coin in coins_txapi:
        txapi = coins_txapi[coin]
    else:
        click.echo('Coin "%s" is not recognized.' % coin, err=True)
        click.echo('Supported coin types: %s' % ', '.join(coins_txapi.keys()), err=True)
        sys.exit(1)

    client.set_tx_api(txapi)

    def default_script_type(address_n):
        # Suggest p2shsegwit for paths whose first element is 49'. The
        # truthiness check covers both None (no path) and an empty path,
        # which previously raised IndexError on address_n[0].
        if address_n and address_n[0] == (49 | PRIME_DERIVATION_FLAG):
            return 'p2shsegwit'
        return 'address'

    def outpoint(s):
        # "txid:vout" -> (raw txid bytes, output index)
        txid, vout = s.split(':')
        return binascii.unhexlify(txid), int(vout)

    inputs = []
    while True:
        click.echo()
        prev = click.prompt('Previous output to spend (txid:vout)', type=outpoint, default='')
        if not prev:
            break
        prev_hash, prev_index = prev
        address_n = click.prompt('BIP-32 path to derive the key', type=client.expand_path)
        amount = click.prompt('Input amount (satoshis)', type=int, default=0)
        sequence = click.prompt('Sequence Number to use (RBF opt-in enabled by default)', type=int, default=0xfffffffd)
        script_type = click.prompt('Input type', type=CHOICE_INPUT_SCRIPT_TYPE, default=default_script_type(address_n))
        # The prompt may hand back either the mapped enum value or the raw
        # default name; normalise to the enum value.
        script_type = script_type if isinstance(script_type, int) else CHOICE_INPUT_SCRIPT_TYPE.typemap[script_type]
        inputs.append(proto.TxInputType(
            address_n=address_n,
            prev_hash=prev_hash,
            prev_index=prev_index,
            amount=amount,
            script_type=script_type,
            sequence=sequence,
        ))

    outputs = []
    while True:
        click.echo()
        address = click.prompt('Output address (for non-change output)', default='')
        if address:
            address_n = None
        else:
            # No address given: ask for a change path instead; an empty
            # answer there ends the output loop.
            address = None
            address_n = click.prompt('BIP-32 path (for change output)', type=client.expand_path, default='')
            if not address_n:
                break
        amount = click.prompt('Amount to spend (satoshis)', type=int)
        script_type = click.prompt('Output type', type=CHOICE_OUTPUT_SCRIPT_TYPE, default=default_script_type(address_n))
        script_type = script_type if isinstance(script_type, int) else CHOICE_OUTPUT_SCRIPT_TYPE.typemap[script_type]
        outputs.append(proto.TxOutputType(
            address_n=address_n,
            address=address,
            amount=amount,
            script_type=script_type,
        ))

    tx_version = click.prompt('Transaction version', type=int, default=2)
    tx_locktime = click.prompt('Transaction locktime', type=int, default=0)

    (signatures, serialized_tx) = client.sign_tx(coin, inputs, outputs, tx_version, tx_locktime)

    client.close()

    click.echo()
    click.echo('Signed Transaction:')
    click.echo(binascii.hexlify(serialized_tx))
    click.echo()
    click.echo('Use the following form to broadcast it to the network:')
    click.echo(txapi.pushtx_url)


#
# Message functions
#


@cli.command(help='Sign message using address of given path.')
@click.option('-c', '--coin', default='Bitcoin')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0")
@click.option('-t', '--script-type', type=click.Choice(['address', 'segwit', 'p2shsegwit']), default='address')
@click.argument('message')
@click.pass_obj
def sign_message(connect, coin, address, message, script_type):
    """Sign `message` with the key at the given path.

    Returns a dict with the original message, the signing address and the
    base64-encoded signature as text.
    """
    client = connect()
    address_n = client.expand_path(address)
    typemap = {
        'address': proto.InputScriptType.SPENDADDRESS,
        'segwit': proto.InputScriptType.SPENDWITNESS,
        'p2shsegwit': proto.InputScriptType.SPENDP2SHWITNESS,
    }
    script_type = typemap[script_type]
    res = client.sign_message(coin, address_n, message, script_type)
    return {
        'message': message,
        'address': res.address,
        # b64encode() returns bytes; decode so the value prints without a
        # b'...' wrapper and serializes in --json mode.
        'signature': base64.b64encode(res.signature).decode()
    }


@cli.command(help='Verify message.')
@click.option('-c', '--coin', default='Bitcoin')
@click.argument('address')
@click.argument('signature')
@click.argument('message')
@click.pass_obj
def verify_message(connect, coin, address, signature, message):
    """Base64-decode `signature` and ask the device to verify it."""
    # Decode before connecting so malformed input fails early.
    raw_signature = base64.b64decode(signature)
    client = connect()
    return client.verify_message(coin, address, raw_signature, message)


@cli.command(help='Sign message with Ethereum address.')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/60'/0'/0/0")
@click.argument('message')
@click.pass_obj
def ethereum_sign_message(connect, address, message):
    """Sign `message` and return message, 0x-prefixed address and signature."""
    device = connect()
    path_n = device.expand_path(address)
    signed = device.ethereum_sign_message(path_n, message)

    address_hex = binascii.hexlify(signed.address).decode()
    signature_hex = binascii.hexlify(signed.signature).decode()
    return {
        'message': message,
        'address': '0x%s' % address_hex,
        'signature': '0x%s' % signature_hex
    }


def ethereum_decode_hex(value):
    """Decode a hex string, with or without a 0x/0X prefix, into bytes."""
    has_prefix = value.startswith(('0x', '0X'))
    digits = value[2:] if has_prefix else value
    return binascii.unhexlify(digits)


@cli.command(help='Verify message signed with Ethereum address.')
@click.argument('address')
@click.argument('signature')
@click.argument('message')
@click.pass_obj
def ethereum_verify_message(connect, address, signature, message):
    """Hex-decode address and signature, then ask the device to verify."""
    # Both values are decoded before the device is opened.
    raw_address = ethereum_decode_hex(address)
    raw_signature = ethereum_decode_hex(signature)
    client = connect()
    return client.ethereum_verify_message(raw_address, raw_signature, message)


@cli.command(help='Encrypt value by given key and path.')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/10016'/0")
@click.argument('key')
@click.argument('value')
@click.pass_obj
def encrypt_keyvalue(connect, address, key, value):
    """Encrypt `value` under `key` at the given path and return the result as hex text."""
    client = connect()
    address_n = client.expand_path(address)
    res = client.encrypt_keyvalue(address_n, key, value.encode())
    # hexlify() yields bytes, which json.dumps rejects in --json mode.
    return binascii.hexlify(res).decode()


@cli.command(help='Decrypt value by given key and path.')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/10016'/0")
@click.argument('key')
@click.argument('value')
@click.pass_obj
def decrypt_keyvalue(connect, address, key, value):
    """Hex-decode `value` and decrypt it under `key` at the given path."""
    device = connect()
    path_n = device.expand_path(address)
    ciphertext = binascii.unhexlify(value)
    return device.decrypt_keyvalue(path_n, key, ciphertext)


@cli.command(help='Encrypt message.')
@click.option('-c', '--coin', default='Bitcoin')
@click.option('-d', '--display-only', is_flag=True)
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0")
@click.argument('pubkey')
@click.argument('message')
@click.pass_obj
def encrypt_message(connect, coin, display_only, address, pubkey, message):
    """Encrypt `message` for the hex-encoded `pubkey`.

    Returns a dict of text values: hex 'nonce', 'message' and 'hmac', plus
    'payload', the base64 of their concatenation (the form decrypt_message
    takes as input).
    """
    client = connect()
    pubkey = binascii.unhexlify(pubkey)
    address_n = client.expand_path(address)
    res = client.encrypt_message(pubkey, message, display_only, coin, address_n)
    # hexlify()/b64encode() return bytes; decode each so the values print
    # without b'...' wrappers and serialize in --json mode.
    return {
        'nonce': binascii.hexlify(res.nonce).decode(),
        'message': binascii.hexlify(res.message).decode(),
        'hmac': binascii.hexlify(res.hmac).decode(),
        'payload': base64.b64encode(res.nonce + res.message + res.hmac).decode(),
    }


@cli.command(help='Decrypt message.')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0")
@click.argument('payload')
@click.pass_obj
def decrypt_message(connect, address, payload):
    """Split a base64 payload into nonce, message and HMAC, then decrypt it."""
    device = connect()
    path_n = device.expand_path(address)
    raw = base64.b64decode(payload)
    # Layout: first 33 bytes nonce, last 8 bytes HMAC, message in between.
    nonce = raw[:33]
    message = raw[33:-8]
    msg_hmac = raw[-8:]
    return device.decrypt_message(path_n, nonce, message, msg_hmac)


#
# Ethereum functions
#


@cli.command(help='Get Ethereum address in hex encoding.')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/60'/0'/0/0")
@click.option('-d', '--show-display', is_flag=True)
@click.pass_obj
def ethereum_get_address(connect, address, show_display):
    """Return the Ethereum address at the given path as a 0x-prefixed hex string."""
    device = connect()
    path_n = device.expand_path(address)
    raw_address = device.ethereum_get_address(path_n, show_display)
    hex_address = binascii.hexlify(raw_address).decode()
    return '0x%s' % hex_address


@cli.command(help='Sign (and optionally publish) Ethereum transaction. Use TO as destination address or set TO to "" for contract creation.')
@click.option('-a', '--host', default='localhost:8545', help='RPC port of ethereum node for automatic gas/nonce estimation and publishing')
@click.option('-c', '--chain-id', type=int, help='EIP-155 chain id (replay protection)')
@click.option('-n', '--address', required=True, help="BIP-32 path to source address, e.g., m/44'/60'/0'/0/0")
@click.option('-v', '--value', default='0', help='Ether amount to transfer, e.g. "100 milliether"')
@click.option('-g', '--gas-limit', type=int, help='Gas limit - Required for offline signing')
@click.option('-t', '--gas-price', help='Gas price, e.g. "20 nanoether" - Required for offline signing')
@click.option('-i', '--nonce', type=int, help='Transaction counter - Required for offline signing')
@click.option('-d', '--data', default='', help='Data as hex string, e.g. 0x12345678')
@click.option('-p', '--publish', is_flag=True, help='Publish transaction via RPC')
@click.argument('to')
@click.pass_obj
def ethereum_sign_tx(connect, host, chain_id, address, value, gas_limit, gas_price, nonce, data, publish, to):
    """Sign an Ethereum transaction and optionally publish it over JSON-RPC.

    `value` and `gas_price` accept either a plain integer (wei) or
    "<integer> <unit>" with a unit from the table below. Any of gas price,
    gas limit and nonce that is not given is queried from the node at
    `host`; the node is also used when `publish` is set.

    Returns a text line with either the raw signed transaction or the
    published transaction ID.

    Raises:
        CallException: with FailureType.DataError for an unknown unit.
    """
    from ethjsonrpc import EthJsonRpc
    import rlp

    # Multipliers to wei.
    ether_units = {
        'wei':          1,
        'kwei':         1000,
        'babbage':      1000,
        'femtoether':   1000,
        'mwei':         1000000,
        'lovelace':     1000000,
        'picoether':    1000000,
        'gwei':         1000000000,
        'shannon':      1000000000,
        'nanoether':    1000000000,
        'nano':         1000000000,
        'szabo':        1000000000000,
        'microether':   1000000000000,
        'micro':        1000000000000,
        'finney':       1000000000000000,
        'milliether':   1000000000000000,
        'milli':        1000000000000000,
        'ether':        1000000000000000000,
        'eth':          1000000000000000000,
    }

    def parse_amount(text, what):
        # Shared by --value and --gas-price: "<int>" or "<int> <unit>".
        if ' ' not in text:
            return int(text)
        number, unit = text.split(' ', 1)
        if unit.lower() not in ether_units:
            # proto.FailureType.DataError matches the other commands in this
            # file (the block previously used proto.Failure_DataError).
            raise CallException(proto.FailureType.DataError, 'Unrecognized %s unit %r' % (what, unit))
        return int(number) * ether_units[unit.lower()]

    value = parse_amount(value, 'ether')

    if gas_price is not None:
        gas_price = parse_amount(gas_price, 'gas price')

    if gas_limit is not None:
        gas_limit = int(gas_limit)

    to_address = ethereum_decode_hex(to)

    client = connect()
    address_n = client.expand_path(address)
    address = '0x%s' % (binascii.hexlify(client.ethereum_get_address(address_n)).decode())

    # The RPC connection is needed for any missing parameter AND for
    # publishing; previously --publish with all offline parameters supplied
    # reached eth_sendRawTransaction with `eth` unbound.
    eth = None
    if publish or gas_price is None or gas_limit is None or nonce is None:
        rpc_host, rpc_port = host.split(':')
        eth = EthJsonRpc(rpc_host, int(rpc_port))

    data = ethereum_decode_hex(data or '')

    if gas_price is None:
        gas_price = eth.eth_gasPrice()

    if gas_limit is None:
        # The node receives the textual `to`, not the decoded bytes.
        gas_limit = eth.eth_estimateGas(
            to_address=to,
            from_address=address,
            value=('0x%x' % value),
            data='0x%s' % (binascii.hexlify(data).decode()))

    if nonce is None:
        nonce = eth.eth_getTransactionCount(address)

    sig = client.ethereum_sign_tx(
        n=address_n,
        nonce=nonce,
        gas_price=gas_price,
        gas_limit=gas_limit,
        to=to_address,
        value=value,
        data=data,
        chain_id=chain_id)

    # The signature tuple returned by the device is appended to the
    # transaction fields before RLP encoding.
    transaction = rlp.encode(
        (nonce, gas_price, gas_limit, to_address, value, data) + sig)
    tx_hex = '0x%s' % binascii.hexlify(transaction).decode()

    if publish:
        tx_hash = eth.eth_sendRawTransaction(tx_hex)
        return 'Transaction published with ID: %s' % tx_hash
    else:
        return 'Signed raw transaction: %s' % tx_hex


#
# NEM functions
#


@cli.command(help='Get NEM address for specified path.')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/43'/0/0")
@click.option('-N', '--network', type=int, default=0x68)
@click.option('-d', '--show-display', is_flag=True)
@click.pass_obj
def nem_get_address(connect, address, network, show_display):
    """Expand the path and request the NEM address for `network`."""
    device = connect()
    path_n = device.expand_path(address)
    return device.nem_get_address(path_n, network, show_display)


@cli.command(help='Sign (and optionally broadcast) NEM transaction.')
@click.option('-n', '--address', help='BIP-32 path to signing key')
@click.option('-f', '--file', type=click.File('r'), default='-', help='Transaction in NIS (RequestPrepareAnnounce) format')
@click.option('-b', '--broadcast', help='NIS to announce transaction to')
@click.pass_obj
def nem_sign_tx(connect, address, file, broadcast):
    """Sign the JSON transaction read from `file`.

    Returns the hex 'data'/'signature' payload, or, when `broadcast` is
    given, the JSON reply from POSTing that payload to
    ``<broadcast>/transaction/announce``.
    """
    device = connect()
    path_n = device.expand_path(address)
    signed = device.nem_sign_tx(path_n, json.load(file))

    payload = {
        "data": binascii.hexlify(signed.data).decode(),
        "signature": binascii.hexlify(signed.signature).decode()
    }

    if not broadcast:
        return payload

    # requests is only needed on the broadcast path.
    import requests
    announce_url = "{}/transaction/announce".format(broadcast)
    return requests.post(announce_url, json=payload).json()


#
# CoSi functions
#


@cli.command(help='Ask device to commit to CoSi signing.')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0")
@click.argument('data')
@click.pass_obj
def cosi_commit(connect, address, data):
    """Hex-decode `data` and request a CoSi commitment for the given path."""
    device = connect()
    path_n = device.expand_path(address)
    raw_data = binascii.unhexlify(data)
    return device.cosi_commit(path_n, raw_data)


@cli.command(help='Ask device to sign using CoSi.')
@click.option('-n', '--address', required=True, help="BIP-32 path, e.g. m/44'/0'/0'/0/0")
@click.argument('data')
@click.argument('global_commitment')
@click.argument('global_pubkey')
@click.pass_obj
def cosi_sign(connect, address, data, global_commitment, global_pubkey):
    """Hex-decode the three arguments and request a CoSi signature."""
    device = connect()
    path_n = device.expand_path(address)
    # Decoded in argument order, after the path has been expanded.
    raw_data = binascii.unhexlify(data)
    raw_commitment = binascii.unhexlify(global_commitment)
    raw_pubkey = binascii.unhexlify(global_pubkey)
    return device.cosi_sign(path_n, raw_data, raw_commitment, raw_pubkey)


#
# Stellar functions
#
@cli.command(help='Get Stellar public address')
@click.option('-n', '--address', required=False, help="BIP32 path. Default primary account is m/44'/148'/0'. Always use hardened paths and the m/44'/148'/ prefix")
@click.pass_obj
def stellar_get_address(connect, address):
    """Return the Stellar address derived from the public key at the given path."""
    device = connect()
    path_n = stellar.expand_path_or_default(device, address)
    # The device replies with a message carrying the raw public key.
    key_response = device.stellar_get_public_key(path_n)
    public_key = key_response.public_key
    return stellar.address_from_public_key(public_key)

@cli.command(help='Sign a string with a Stellar key')
@click.option('-n', '--address', required=False, help="BIP32 path. Default primary account is m/44'/148'/0'. Always use hardened paths and the m/44'/148'/ prefix")
@click.argument('message')
@click.pass_obj
def stellar_sign_message(connect, address, message):
    """Sign `message` with the Stellar key at the given path; return the base64 signature as text."""
    client = connect()
    address_n = stellar.expand_path_or_default(client, address)
    response = client.stellar_sign_message(address_n, message)
    # b64encode() returns bytes, which json.dumps rejects in --json mode.
    return base64.b64encode(response.signature).decode()

@cli.command(help='Verify that a signature is valid')
@click.option('--message-is-b64/--no-message-is-b64', default=False, required=False, help="If set, the message argument will be interpreted as a base64-encoded string")
@click.argument('address')
@click.argument('signatureb64')
@click.argument('message')
@click.pass_obj
def stellar_verify_message(connect, address, message_is_b64, signatureb64, message):
    """Verify a base64 signature over `message` against a Stellar address.

    Returns a success string, or prints an error and exits with status 1
    when verification fails.
    """
    message_bytes = base64.b64decode(message) if message_is_b64 else message.encode('utf-8')
    pubkey_bytes = stellar.address_to_public_key(address)

    device = connect()
    signature = base64.b64decode(signatureb64)

    if not device.stellar_verify_message(pubkey_bytes, signature, message_bytes):
        print("ERROR: invalid signature, verification failed")
        sys.exit(1)

    return "Success: message verified"

@cli.command(help='Sign a base64-encoded transaction envelope')
@click.option('-n', '--address', required=False, help="BIP32 path. Default primary account is m/44'/148'/0'. Always use hardened paths and the m/44'/148'/ prefix")
@click.option('-N', '--network-passphrase', required=False, help="Network passphrase (blank for public network). Testnet is: 'Test SDF Network ; September 2015'")
@click.argument('b64envelope')
@click.pass_obj
def stellar_sign_transaction(connect, b64envelope, address, network_passphrase):
    """Sign a base64 transaction envelope; return the base64 signature as text.

    NOTE: --network-passphrase previously declared the short flag -n, the
    same as --address, so one short flag was claimed by two options. It now
    uses -N (as `nem_get_address --network` does), and -n means the BIP32
    path as in every other command. The long option names are unchanged.
    """
    client = connect()
    address_n = stellar.expand_path_or_default(client, address)
    # The response carries the raw signature bytes.
    resp = client.stellar_sign_transaction(base64.b64decode(b64envelope), address_n, network_passphrase)

    # b64encode() returns bytes, which json.dumps rejects in --json mode.
    return base64.b64encode(resp.signature).decode()

#
# Main
#


if __name__ == '__main__':
    # Run the click command group only when executed as a script.
    cli()