# This file is part of the TREZOR project.
# Copyright (C) 2012-2016 Marek Palatinus <slush@satoshilabs.com>
# Copyright (C) 2012-2016 Pavol Rusnak <stick@satoshilabs.com>
# Copyright (C) 2016      Jochen Hoenicke <hoenicke@gmail.com>
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this library.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function, absolute_import

import os
import sys
import time
import binascii
import hashlib
import unicodedata
# import json
import getpass

from mnemonic import Mnemonic

from . import tools
# from . import mapping
from . import messages_pb2 as proto
from . import types_pb2 as types
from .debuglink import DebugLink

# Python2 vs Python3: on Python 2 make input() behave like raw_input()
# (return the typed text without evaluating it). On Python 3 raw_input
# does not exist and the built-in input() already does the right thing.
try:
    input = raw_input
except NameError:
    pass

# Saving screenshots of the device display needs PIL, whose import is
# disabled below. The flag must still exist (and be off) because
# DebugLinkMixin.call_raw() reads it on every call.
# try:
#     from PIL import Image
#     SCREENSHOT = True
# except:
#     SCREENSHOT = False
SCREENSHOT = False


# Elliptic curve name used when a caller does not specify one.
DEFAULT_CURVE = 'secp256k1'

# monkeypatching: text formatting of protobuf messages

def getch():
    """Read and return a single keypress from the terminal without echo.

    Uses msvcrt on platforms without termios (Windows); otherwise the tty is
    switched to raw mode for the duration of the one-character read.
    """
    try:
        import termios
    except ImportError:
        # Non-POSIX. Return msvcrt's (Windows') getch.
        import msvcrt
        return msvcrt.getch()

    # POSIX system. Create and return a getch that manipulates the tty.
    import sys
    import tty

    def _getch():
        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            ch = sys.stdin.read(1)
        finally:
            # Always restore the terminal, even if the read is interrupted.
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
        return ch

    return _getch()

def get_buttonrequest_value(code):
    """Return the ButtonRequestType name whose integer value equals `code`.

    Raises IndexError when no name maps to `code`.
    """
    matching_names = []
    for name, value in types.ButtonRequestType.items():
        if value == code:
            matching_names.append(name)
    return matching_names[0]

def pprint(msg):
    """Return a short human-readable description of a protobuf message.

    The result holds the message class name and its serialized size; the
    message's text form is appended except for FirmwareUpload and SelfTest
    messages, whose body is left out.
    """
    msg_class = msg.__class__.__name__
    msg_size = msg.ByteSize()
    if isinstance(msg, (proto.FirmwareUpload, proto.SelfTest)):
        return "<%s> (%d bytes):\n" % (msg_class, msg_size)
    else:
        return "<%s> (%d bytes):\n%s" % (msg_class, msg_size, msg)

def log(msg):
    """Write `msg` followed by a newline to stderr and flush it.

    NOTE(review): the body of this function was missing; stderr is used so
    that diagnostics do not mix with data printed on stdout - confirm.
    """
    sys.stderr.write(msg)
    sys.stderr.write('\n')
    sys.stderr.flush()

class CallException(Exception):
    """Error raised for a failed device call.

    `args` holds the pair (code, message).
    """

    def __init__(self, code, message):
        # Passing both values to the base constructor stores them in `args`.
        super(CallException, self).__init__(code, message)

class PinException(CallException):
    """CallException raised by BaseClient.callback_Failure for PIN failure codes."""

class field(object):
    """Decorator that reduces the wrapped call's result to one attribute.

    The decorated function is called as usual and the attribute named at
    decoration time is read from whatever it returns; a result without
    that attribute makes getattr() raise.
    """

    def __init__(self, field):
        self.field = field

    def __call__(self, f):
        decorator = self

        def wrapped_f(*args, **kwargs):
            result = f(*args, **kwargs)
            return getattr(result, decorator.field)

        return wrapped_f

class expect(object):
    """Decorator that type-checks the value returned by the wrapped call.

    The result must be an instance of one of the classes given at
    decoration time; otherwise an Exception is raised.
    """

    def __init__(self, *expected):
        self.expected = expected

    def __call__(self, f):
        allowed = self

        def wrapped_f(*args, **kwargs):
            result = f(*args, **kwargs)
            if isinstance(result, allowed.expected):
                return result
            raise Exception("Got %s, expected %s" % (result.__class__, allowed.expected))

        return wrapped_f

def session(f):
    """Decorator wrapping a BaseClient method in a transport session.

    The first positional argument must be the client. The session is begun
    before calling `f` and always ended afterwards, also when `f` raises.

    NOTE(review): the activation/deactivation calls were missing from this
    function; session_begin()/session_end() on the transport are assumed -
    confirm against the transport implementations.
    """
    def wrapped_f(*args, **kwargs):
        client = args[0]
        client.transport.session_begin()
        try:
            return f(*args, **kwargs)
        finally:
            client.transport.session_end()
    return wrapped_f

def normalize_nfc(txt):
    """Return `txt` as a unicode string in NFC normalization form.

    Byte strings are decoded as UTF-8 first. On Python 2 both `unicode` and
    `str` are accepted, on Python 3 both `bytes` and `str`. Any other type
    raises an Exception.
    """
    if sys.version_info[0] < 3:
        if isinstance(txt, unicode):  # noqa: F821 - name exists on Python 2 only
            return unicodedata.normalize('NFC', txt)
        if isinstance(txt, str):
            return unicodedata.normalize('NFC', txt.decode('utf-8'))
    else:
        if isinstance(txt, bytes):
            return unicodedata.normalize('NFC', txt.decode('utf-8'))
        if isinstance(txt, str):
            return unicodedata.normalize('NFC', txt)

    raise Exception('unicode/str or bytes/str expected')

class BaseClient(object):
    """Very basic layer for sending raw protobuf messages to the device
    and getting its response back.

    NOTE(review): the transport is assumed to provide write() and close()
    next to read_blocking(); the statements using them were missing from
    this class - confirm against the transport implementations.
    """

    def __init__(self, transport, **kwargs):
        self.transport = transport
        super(BaseClient, self).__init__()  # *args, **kwargs)

    def cancel(self):
        """Send a Cancel message to the device without reading a reply."""
        self.transport.write(proto.Cancel())

    def call_raw(self, msg):
        """Send `msg` and return the next message read from the transport."""
        self.transport.write(msg)
        return self.transport.read_blocking()

    def call(self, msg):
        """Send `msg` and return the final response of the exchange.

        If this object has a method named callback_<ResponseClassName>, it
        is called with the response and the message it returns is sent in
        turn; this repeats until a response without a handler arrives.

        Raises Exception when a handler returns None.
        """
        resp = self.call_raw(msg)
        handler_name = "callback_%s" % resp.__class__.__name__
        handler = getattr(self, handler_name, None)

        if handler is not None:
            msg = handler(resp)
            if msg is None:
                raise Exception("Callback %s must return protobuf message, not None" % handler)
            resp = self.call(msg)

        return resp

    def callback_Failure(self, msg):
        """Turn a Failure response into an exception.

        Raises PinException for the PIN failure codes and CallException for
        every other code.
        """
        if msg.code in (types.Failure_PinInvalid,
                        types.Failure_PinCancelled, types.Failure_PinExpected):
            raise PinException(msg.code, msg.message)

        raise CallException(msg.code, msg.message)

    def close(self):
        """Close the underlying transport."""
        self.transport.close()

class VerboseWireMixin(object):
    """Mixin that logs every message sent and every reply received."""

    def call_raw(self, msg):
        log("SENDING " + pprint(msg))
        reply = super(VerboseWireMixin, self).call_raw(msg)
        log("RECEIVED " + pprint(reply))
        return reply

class TextUIMixin(object):
    """Easy text-based UI integration between the device and wallet.

    You can implement similar functionality by implementing your own
    GuiMixin with graphical widgets for every type of these callbacks.
    Each callback_<MessageName> method receives the device's request and
    returns the protobuf message to answer it with.
    """

    def __init__(self, *args, **kwargs):
        super(TextUIMixin, self).__init__(*args, **kwargs)

    def callback_ButtonRequest(self, msg):
        """Acknowledge a button request."""
        # log("Sending ButtonAck for %s " % get_buttonrequest_value(msg.code))
        return proto.ButtonAck()

    def callback_RecoveryMatrix(self, msg):
        """Read one keypad digit for matrix recovery and return it.

        Ctrl-C / Ctrl-D cancel the workflow and backspace / delete are
        forwarded as a backspace. The layout help is printed only while
        self.recovery_matrix_first_pass is true.
        """
        if self.recovery_matrix_first_pass:
            self.recovery_matrix_first_pass = False
            log("Use the numeric keypad to describe positions.  For the word list use only left and right keys. The layout is:")
            log("    7 8 9     7 | 9")
            log("    4 5 6     4 | 6")
            log("    1 2 3     1 | 3")
        while True:
            character = getch()
            if character in ('\x03', '\x04'):
                return proto.Cancel()

            if character in ('\x08', '\x7f'):
                return proto.WordAck(word='\x08')

            # ignore middle column if only 6 keys requested.
            if msg.type == types.WordRequestType_Matrix6 and character in ('2', '5', '8'):
                continue

            if character.isdigit():
                return proto.WordAck(word=character)

    def callback_PinMatrixRequest(self, msg):
        """Prompt for the PIN as keypad positions and return the answer."""
        if msg.type == types.PinMatrixRequestType_Current:
            desc = 'current PIN'
        elif msg.type == types.PinMatrixRequestType_NewFirst:
            desc = 'new PIN'
        elif msg.type == types.PinMatrixRequestType_NewSecond:
            desc = 'new PIN again'
        else:
            desc = 'PIN'

        log("Use the numeric keypad to describe number positions. The layout is:")
        log("    7 8 9")
        log("    4 5 6")
        log("    1 2 3")
        log("Please enter %s: " % desc)
        pin = getpass.getpass('')
        return proto.PinMatrixAck(pin=pin)

    def callback_PassphraseRequest(self, msg):
        """Prompt twice for the passphrase and return it when both match."""
        log("Passphrase required: ")
        passphrase = getpass.getpass('')
        log("Confirm your Passphrase: ")
        if passphrase == getpass.getpass(''):
            passphrase = normalize_nfc(passphrase)
            return proto.PassphraseAck(passphrase=passphrase)
        else:
            # Nothing is returned on a mismatch; BaseClient.call() rejects a
            # None answer with an exception, which aborts the exchange.
            log("Passphrase did not match! ")

    def callback_WordRequest(self, msg):
        """Answer a request for one recovery word.

        Matrix requests are delegated to callback_RecoveryMatrix(); otherwise
        a word is read from standard input and, when self.expand is set,
        expanded through self.mnemonic_wordlist.
        """
        if msg.type in (types.WordRequestType_Matrix9,
                        types.WordRequestType_Matrix6):
            return self.callback_RecoveryMatrix(msg)
        log("Enter one word of mnemonic: ")
        word = input()
        if self.expand:
            word = self.mnemonic_wordlist.expand_word(word)
        return proto.WordAck(word=word)

class DebugLinkMixin(object):
    """Automatic responses and other functionality for unit tests.

    The callbacks answer device requests without user interaction. This
    mixin should be used only for unit testing, because it will fail to
    work without the special DebugLink interface provided by the device.
    """

    def __init__(self, *args, **kwargs):
        super(DebugLinkMixin, self).__init__(*args, **kwargs)
        self.debug = None
        self.in_with_statement = 0
        self.button_wait = 0
        self.screenshot_id = 0

        # Always press Yes and provide correct pin
        self.setup_debuglink(True, True)

        # Do not expect any specific response from device
        self.expected_responses = None

        # Use blank passphrase
        self.set_passphrase('')

    def close(self):
        """Close the client and, when one was set, the debug link."""
        super(DebugLinkMixin, self).close()
        if self.debug:
            # NOTE(review): assumes DebugLink provides close() - confirm.
            self.debug.close()

    def set_debuglink(self, debug_transport):
        """Create the DebugLink on top of `debug_transport`."""
        self.debug = DebugLink(debug_transport)

    def set_buttonwait(self, secs):
        """Set the delay in seconds applied before answering a button request."""
        self.button_wait = secs

    def __enter__(self):
        # For usage in with/expected_responses
        self.in_with_statement += 1
        return self

    def __exit__(self, _type, value, traceback):
        self.in_with_statement -= 1

        if _type is not None:
            # Another exception raised
            return False

        # return isinstance(value, TypeError)
        # Evaluate missed responses in 'with' statement
        if self.expected_responses is not None and len(self.expected_responses):
            raise Exception("Some of expected responses didn't come from device: %s" %
                            [pprint(x) for x in self.expected_responses])

        # Cleanup
        self.expected_responses = None
        return False

    def set_expected_responses(self, expected):
        """Set the list of responses the device must send, in order.

        Raises Exception when called outside a 'with' statement.
        """
        if not self.in_with_statement:
            raise Exception("Must be called inside 'with' statement")
        self.expected_responses = expected

    def setup_debuglink(self, button, pin_correct):
        """Choose the button to press and whether to enter the correct PIN."""
        self.button = button  # True -> YES button, False -> NO button
        self.pin_correct = pin_correct

    def set_passphrase(self, passphrase):
        """Set the passphrase sent on a passphrase request (NFC-normalized)."""
        self.passphrase = normalize_nfc(passphrase)

    def set_mnemonic(self, mnemonic):
        """Set the recovery words, given as one space-separated string."""
        self.mnemonic = normalize_nfc(mnemonic).split(' ')

    def call_raw(self, msg):
        """Send `msg`, check the reply against the expected responses and
        return it; optionally saves a screenshot of the display first."""

        if SCREENSHOT and self.debug:
            layout = self.debug.read_layout()
            im = Image.new("RGB", (128, 64))
            pix = im.load()
            for x in range(128):
                for y in range(64):
                    rx, ry = 127 - x, 63 - y
                    # Floor division: the index must stay an int on Python 3.
                    if (ord(layout[rx + (ry // 8) * 128]) & (1 << (ry % 8))) > 0:
                        pix[x, y] = (255, 255, 255)
            im.save('scr%05d.png' % self.screenshot_id)
            self.screenshot_id += 1

        resp = super(DebugLinkMixin, self).call_raw(msg)
        self._check_request(resp)
        return resp

    def _check_request(self, msg):
        """Compare `msg` with the next expected response, if any are set.

        Raises CallException when nothing more is expected, when the message
        class differs, or when a field filled in the expected message is
        missing from or different in `msg`.
        """
        if self.expected_responses is not None:
            try:
                expected = self.expected_responses.pop(0)
            except IndexError:
                raise CallException(types.Failure_UnexpectedMessage,
                                    "Got %s, but no message has been expected" % pprint(msg))

            if msg.__class__ != expected.__class__:
                raise CallException(types.Failure_UnexpectedMessage,
                                    "Expected %s, got %s" % (pprint(expected), pprint(msg)))

            fields = expected.ListFields()  # only filled (including extensions)
            for descriptor, value in fields:
                if not msg.HasField(descriptor.name) or getattr(msg, descriptor.name) != value:
                    raise CallException(types.Failure_UnexpectedMessage,
                                        "Expected %s, got %s" % (pprint(expected), pprint(msg)))

    def callback_ButtonRequest(self, msg):
        """Press the configured button through the debug link and acknowledge."""
        log("ButtonRequest code: " + get_buttonrequest_value(msg.code))

        log("Pressing button " + str(self.button))
        if self.button_wait:
            log("Waiting %d seconds " % self.button_wait)
            time.sleep(self.button_wait)
        # NOTE(review): assumes DebugLink provides press_button() - confirm.
        self.debug.press_button(self.button)
        return proto.ButtonAck()

    def callback_PinMatrixRequest(self, msg):
        """Answer with the PIN read from the device, or a fixed wrong one."""
        if self.pin_correct:
            pin = self.debug.read_pin_encoded()
        else:
            pin = '444222'
        return proto.PinMatrixAck(pin=pin)

    def callback_PassphraseRequest(self, msg):
        """Answer with the passphrase stored by set_passphrase()."""
        log("Provided passphrase: '%s'" % self.passphrase)
        return proto.PassphraseAck(passphrase=self.passphrase)

    def callback_WordRequest(self, msg):
        """Answer with the word the device expects, as read from the debug link.

        When the debug link reports a position instead of a word, the word
        at that 1-based position of the stored mnemonic is sent.
        """
        (word, pos) = self.debug.read_recovery_word()
        if word != '':
            return proto.WordAck(word=word)
        if pos != 0:
            return proto.WordAck(word=self.mnemonic[pos - 1])

        raise Exception("Unexpected call")

class ProtocolMixin(object):
    PRIME_DERIVATION_FLAG = 0x80000000
    VENDORS = ('bitcointrezor.com', 'trezor.io')

    def __init__(self, *args, **kwargs):
        super(ProtocolMixin, self).__init__(*args, **kwargs)
        self.tx_api = None
        # Load self.features right away: several methods read it and
        # init_device() is the only place that sets it.
        self.init_device()

    def set_tx_api(self, tx_api):
        # Store the object whose get_tx() is used to look up the previous
        # transactions referenced by inputs when signing.
        self.tx_api = tx_api

    def init_device(self):
        """Send Initialize and cache the returned Features in self.features.

        Raises Exception when the response is not a Features message or the
        reported vendor is not listed in VENDORS.
        """
        checked_call = expect(proto.Features)(self.call)
        self.features = checked_call(proto.Initialize())
        vendor = str(self.features.vendor)
        if vendor not in self.VENDORS:
            raise Exception("Unsupported device")

    def _get_local_entropy(self):
        """Return 32 random bytes from the operating system."""
        entropy = os.urandom(32)
        return entropy

    def _convert_prime(self, n):
        """Return a copy of path `n` with negative indexes made hardened.

        A negative index is replaced by its absolute value with
        PRIME_DERIVATION_FLAG set; other indexes are passed through.
        """
        converted = []
        for index in n:
            if index < 0:
                index = int(abs(index) | self.PRIME_DERIVATION_FLAG)
            converted.append(index)
        return converted

    @staticmethod
    def expand_path(n):
        """Convert a BIP32 path string to a list of uint32 integers.

        Components ending with an apostrophe or starting with a minus sign
        get the prime (hardened) flag:
        0/-1/1' -> [0, 0x80000001, 0x80000001]

        A leading 'm' is dropped and a leading coin name is replaced by
        44'/<coin constant>'. An empty path, or 'm' alone, gives [].
        Raises ValueError for a component that is not an integer.
        """
        if not n:
            return []

        n = n.split('/')

        # m/a/b/c => a/b/c
        if n[0] == 'm':
            n = n[1:]

        # coin_name/a/b/c => 44'/SLIP44_constant'/a/b/c
        coins = {
            "Bitcoin": 0,
            "Testnet": 1,
            "Namecoin": 7,
            "Litecoin": 2,
            "Dogecoin": 3,
            "Dash": 5,
            "Ether": 60,
            "EtherClassic": 61,
            "Zcash": 133,
            "Decred": 42,
        }
        # `n` is empty here when the path was just 'm'.
        if n and n[0] in coins:
            n = ["44'", "%d'" % coins[n[0]]] + n[1:]

        path = []
        for x in n:
            prime = False
            if x.endswith("'"):
                x = x.replace('\'', '')
                prime = True
            if x.startswith('-'):
                prime = True

            x = abs(int(x))

            if prime:
                x |= ProtocolMixin.PRIME_DERIVATION_FLAG

            path.append(x)

        return path

    def get_public_node(self, n, ecdsa_curve_name=DEFAULT_CURVE, show_display=False, coin_name=None):
        """Send GetPublicKey for path `n` and return the device's response.

        A false `ecdsa_curve_name` falls back to DEFAULT_CURVE.
        """
        address_n = self._convert_prime(n)
        curve = ecdsa_curve_name if ecdsa_curve_name else DEFAULT_CURVE
        request = proto.GetPublicKey(address_n=address_n, ecdsa_curve_name=curve, show_display=show_display, coin_name=coin_name)
        return self.call(request)

    def get_address(self, coin_name, n, show_display=False, multisig=None, script_type=types.SPENDADDRESS):
        """Send GetAddress for path `n` and return the device's response.

        The `multisig` field is only included in the request when given.
        """
        n = self._convert_prime(n)
        if multisig:
            return self.call(proto.GetAddress(address_n=n, coin_name=coin_name, show_display=show_display, multisig=multisig, script_type=script_type))
        else:
            return self.call(proto.GetAddress(address_n=n, coin_name=coin_name, show_display=show_display, script_type=script_type))

    def ethereum_get_address(self, n, show_display=False, multisig=None):
        """Send EthereumGetAddress for path `n` and return the response.

        `multisig` is accepted but not used.
        """
        address_n = self._convert_prime(n)
        request = proto.EthereumGetAddress(address_n=address_n, show_display=show_display)
        return self.call(request)

    def ethereum_sign_tx(self, n, nonce, gas_price, gas_limit, to, value, data=None, chain_id=None):
        """Sign an Ethereum transaction with the key at path `n`.

        `data` is sent in pieces: the first 1024 bytes with the initial
        request and the rest in chunks of the size the device asks for.

        Returns the tuple (signature_v, signature_r, signature_s).
        """
        def int_to_big_endian(value):
            import rlp.utils
            if value == 0:
                return b''
            return rlp.utils.int_to_big_endian(value)

        n = self._convert_prime(n)

        # NOTE(review): the constructor arguments were missing; the numeric
        # fields are assumed to be sent big-endian via the helper above,
        # which is otherwise unused - confirm against the protobuf schema.
        msg = proto.EthereumSignTx(
            address_n=n,
            nonce=int_to_big_endian(nonce),
            gas_price=int_to_big_endian(gas_price),
            gas_limit=int_to_big_endian(gas_limit),
            value=int_to_big_endian(value))

        if to:
            msg.to = to

        if data:
            msg.data_length = len(data)
            data, chunk = data[1024:], data[:1024]
            msg.data_initial_chunk = chunk

        if chain_id:
            msg.chain_id = chain_id

        response = self.call(msg)

        # The device keeps requesting further data chunks until it has all.
        while response.HasField('data_length'):
            data_length = response.data_length
            data, chunk = data[data_length:], data[:data_length]
            response = self.call(proto.EthereumTxAck(data_chunk=chunk))

        return response.signature_v, response.signature_r, response.signature_s

    def ethereum_sign_message(self, n, message):
        """Sign `message` with the Ethereum key at path `n`; return the response."""
        address_n = self._convert_prime(n)
        # Convert message to UTF8 NFC (seems to be a bitcoin-qt standard)
        encoded = normalize_nfc(message).encode('utf-8')
        request = proto.EthereumSignMessage(address_n=address_n, message=encoded)
        return self.call(request)

    def ethereum_verify_message(self, address, signature, message):
        """Return True if the device accepts the signature, False otherwise.

        A CallException raised by the call counts as a failed verification.
        """
        # Convert message to UTF8 NFC (seems to be a bitcoin-qt standard)
        message = normalize_nfc(message).encode('utf-8')
        try:
            resp = self.call(proto.EthereumVerifyMessage(address=address, signature=signature, message=message))
        except CallException as e:
            resp = e
        if isinstance(resp, proto.Success):
            return True
        return False

    def get_entropy(self, size):
        """Send GetEntropy with the given `size` and return the response."""
        request = proto.GetEntropy(size=size)
        return self.call(request)

    def ping(self, msg, button_protection=False, pin_protection=False, passphrase_protection=False):
        """Send a Ping carrying text `msg` and return the device's response.

        The three protection flags are passed on as the Ping fields of the
        same names.
        """
        msg = proto.Ping(message=msg,
                         button_protection=button_protection,
                         pin_protection=pin_protection,
                         passphrase_protection=passphrase_protection)
        return self.call(msg)

    def get_device_id(self):
        """Return the device_id from the cached Features."""
        features = self.features
        return features.device_id

    def apply_settings(self, label=None, language=None, use_passphrase=None, homescreen=None):
        """Send ApplySettings with the given values and reload Features.

        `label`, `use_passphrase` and `homescreen` are sent unless None;
        `language` is sent only when truthy. Returns the device's response.
        """
        request = proto.ApplySettings()
        if language:
            request.language = language
        sent_unless_none = (
            ('label', label),
            ('use_passphrase', use_passphrase),
            ('homescreen', homescreen),
        )
        for name, value in sent_unless_none:
            if value is not None:
                setattr(request, name, value)

        response = self.call(request)
        self.init_device()  # Reload Features
        return response

    def apply_flags(self, flags):
        """Send ApplyFlags, reload Features and return the response."""
        request = proto.ApplyFlags(flags=flags)
        response = self.call(request)
        self.init_device()  # Reload Features
        return response

    def clear_session(self):
        """Send ClearSession and return the device's response."""
        request = proto.ClearSession()
        return self.call(request)

    def change_pin(self, remove=False):
        """Send ChangePin, re-read Features and return the response."""
        request = proto.ChangePin(remove=remove)
        response = self.call(request)
        self.init_device()  # Re-read features
        return response

    def sign_message(self, coin_name, n, message, script_type=types.SPENDADDRESS):
        """Sign `message` with the key at path `n`; return the response."""
        address_n = self._convert_prime(n)
        # Convert message to UTF8 NFC (seems to be a bitcoin-qt standard)
        encoded = normalize_nfc(message).encode("utf-8")
        request = proto.SignMessage(coin_name=coin_name, address_n=address_n, message=encoded, script_type=script_type)
        return self.call(request)

    def sign_identity(self, identity, challenge_hidden, challenge_visual, ecdsa_curve_name=DEFAULT_CURVE):
        """Send SignIdentity with the given challenge and return the response."""
        request = proto.SignIdentity(
            identity=identity,
            challenge_hidden=challenge_hidden,
            challenge_visual=challenge_visual,
            ecdsa_curve_name=ecdsa_curve_name,
        )
        return self.call(request)

    def get_ecdh_session_key(self, identity, peer_public_key, ecdsa_curve_name=DEFAULT_CURVE):
        """Send GetECDHSessionKey and return the device's response."""
        request = proto.GetECDHSessionKey(
            identity=identity,
            peer_public_key=peer_public_key,
            ecdsa_curve_name=ecdsa_curve_name,
        )
        return self.call(request)

    def set_u2f_counter(self, u2f_counter):
        """Send SetU2FCounter and return the device's response."""
        request = proto.SetU2FCounter(u2f_counter=u2f_counter)
        return self.call(request)

    def verify_message(self, coin_name, address, signature, message):
        """Return True if the device accepts the signature, False otherwise.

        A CallException raised by the call counts as a failed verification.
        """
        # Convert message to UTF8 NFC (seems to be a bitcoin-qt standard)
        message = normalize_nfc(message).encode("utf-8")
        try:
            resp = self.call(proto.VerifyMessage(address=address, signature=signature, message=message, coin_name=coin_name))
        except CallException as e:
            resp = e
        if isinstance(resp, proto.Success):
            return True
        return False

    def encrypt_message(self, pubkey, message, display_only, coin_name, n):
        """Send EncryptMessage and return the device's response.

        `coin_name` and the path `n` are included only when both are given.
        """
        if coin_name and n:
            n = self._convert_prime(n)
            return self.call(proto.EncryptMessage(pubkey=pubkey, message=message, display_only=display_only, coin_name=coin_name, address_n=n))
        else:
            return self.call(proto.EncryptMessage(pubkey=pubkey, message=message, display_only=display_only))

    def decrypt_message(self, n, nonce, message, msg_hmac):
        """Send DecryptMessage for path `n` and return the response."""
        address_n = self._convert_prime(n)
        request = proto.DecryptMessage(address_n=address_n, nonce=nonce, message=message, hmac=msg_hmac)
        return self.call(request)

    def encrypt_keyvalue(self, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''):
        """Send CipherKeyValue in encrypt mode for the key at path `n`.

        NOTE(review): all arguments after address_n were missing from the
        call; they are rebuilt from this method's parameters, and
        `encrypt=True` is assumed to be what selects encryption - confirm
        against the CipherKeyValue protobuf definition.
        """
        n = self._convert_prime(n)
        return self.call(proto.CipherKeyValue(address_n=n,
                                              key=key,
                                              value=value,
                                              encrypt=True,
                                              ask_on_encrypt=ask_on_encrypt,
                                              ask_on_decrypt=ask_on_decrypt,
                                              iv=iv))

    def decrypt_keyvalue(self, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''):
        """Send CipherKeyValue in decrypt mode for the key at path `n`.

        NOTE(review): all arguments after address_n were missing from the
        call; they are rebuilt from this method's parameters, and
        `encrypt=False` is assumed to be what selects decryption - confirm
        against the CipherKeyValue protobuf definition.
        """
        n = self._convert_prime(n)
        return self.call(proto.CipherKeyValue(address_n=n,
                                              key=key,
                                              value=value,
                                              encrypt=False,
                                              ask_on_encrypt=ask_on_encrypt,
                                              ask_on_decrypt=ask_on_decrypt,
                                              iv=iv))

    def estimate_tx_size(self, coin_name, inputs, outputs):
        """Send EstimateTxSize with the input/output counts; return the response."""
        request = proto.EstimateTxSize()
        request.coin_name = coin_name
        request.inputs_count, request.outputs_count = len(inputs), len(outputs)
        return self.call(request)

    def _prepare_simple_sign_tx(self, coin_name, inputs, outputs):
        """Build the SimpleSignTx message for the given inputs and outputs.

        Every distinct previous transaction referenced by the inputs is
        fetched once through self.tx_api and added to the message.

        Raises Exception when no tx_api is set.

        NOTE(review): the statements copying inputs/outputs and the fetched
        transaction into the message were missing; they are rebuilt using
        the same tx_api.get_tx() call as _prepare_sign_tx() - confirm.
        """
        msg = proto.SimpleSignTx()
        msg.coin_name = coin_name
        msg.inputs.extend(inputs)
        msg.outputs.extend(outputs)

        known_hashes = set()
        for inp in inputs:
            if inp.prev_hash in known_hashes:
                continue

            tx = msg.transactions.add()
            if self.tx_api:
                tx.CopyFrom(self.tx_api.get_tx(binascii.hexlify(inp.prev_hash).decode('utf-8')))
            else:
                raise Exception('TX_API not defined')
            known_hashes.add(inp.prev_hash)

        return msg

    def simple_sign_tx(self, coin_name, inputs, outputs):
        """Sign with a single SimpleSignTx call; return the serialized tx."""
        request = self._prepare_simple_sign_tx(coin_name, inputs, outputs)
        response = self.call(request)
        return response.serialized.serialized_tx

    def _prepare_sign_tx(self, coin_name, inputs, outputs):
        """Return the transactions sign_tx() serves to the device, by hash.

        The transaction being signed is stored under the empty hash b'';
        every distinct previous transaction referenced by the inputs is
        fetched once through self.tx_api and stored under its hash.

        Raises Exception when no tx_api is set.
        """
        tx = types.TransactionType()
        # sign_tx() answers requests with an empty tx_hash from this entry,
        # so it has to carry the inputs and outputs being signed.
        tx.inputs.extend(inputs)
        tx.outputs.extend(outputs)

        txes = {}
        txes[b''] = tx

        known_hashes = set()
        for inp in inputs:
            if inp.prev_hash in known_hashes:
                continue

            if self.tx_api:
                txes[inp.prev_hash] = self.tx_api.get_tx(binascii.hexlify(inp.prev_hash).decode('utf-8'))
            else:
                raise Exception('TX_API not defined')
            known_hashes.add(inp.prev_hash)

        return txes

    def sign_tx(self, coin_name, inputs, outputs, version=None, lock_time=None, debug_processor=None):
        """Sign a transaction by answering the device's TxRequest messages.

        After the initial SignTx the device repeatedly asks for metadata,
        inputs, outputs or extra data of the transaction being signed or of
        a previous transaction; each request is answered with a TxAck until
        the device reports TXFINISHED.

        `debug_processor`, when given, is called as
        debug_processor(request, prepared_response) for input and output
        requests and its return value is sent instead.

        Returns the tuple (signatures, serialized_tx).
        Raises CallException for a Failure or unexpected response, and
        Exception when a signature arrives twice or is missing at the end.
        """

        start = time.time()
        txes = self._prepare_sign_tx(coin_name, inputs, outputs)

        # Prepare and send initial message
        tx = proto.SignTx()
        tx.inputs_count = len(inputs)
        tx.outputs_count = len(outputs)
        tx.coin_name = coin_name
        if version is not None:
            tx.version = version
        if lock_time is not None:
            tx.lock_time = lock_time
        res = self.call(tx)

        # Prepare structure for signatures
        signatures = [None] * len(inputs)
        serialized_tx = b''

        counter = 0
        while True:
            counter += 1

            # CallException takes (code, message).
            if isinstance(res, proto.Failure):
                raise CallException(res.code, "Signing failed")

            if not isinstance(res, proto.TxRequest):
                raise CallException(types.Failure_UnexpectedMessage, "Unexpected message")

            # If there's some part of signed transaction, let's add it
            if res.HasField('serialized') and res.serialized.HasField('serialized_tx'):
                log("RECEIVED PART OF SERIALIZED TX (%d BYTES)" % len(res.serialized.serialized_tx))
                serialized_tx += res.serialized.serialized_tx

            if res.HasField('serialized') and res.serialized.HasField('signature_index'):
                if signatures[res.serialized.signature_index] is not None:
                    raise Exception("Signature for index %d already filled" % res.serialized.signature_index)
                signatures[res.serialized.signature_index] = res.serialized.signature

            if res.request_type == types.TXFINISHED:
                # Device didn't ask for more information, finish workflow
                break

            # Device asked for one more information, let's process it.
            # An empty tx_hash selects the transaction being signed.
            current_tx = txes[res.details.tx_hash]

            if res.request_type == types.TXMETA:
                msg = types.TransactionType()
                msg.version = current_tx.version
                msg.lock_time = current_tx.lock_time
                msg.inputs_cnt = len(current_tx.inputs)
                if res.details.tx_hash:
                    msg.outputs_cnt = len(current_tx.bin_outputs)
                else:
                    msg.outputs_cnt = len(current_tx.outputs)
                msg.extra_data_len = len(current_tx.extra_data)
                res = self.call(proto.TxAck(tx=msg))

            elif res.request_type == types.TXINPUT:
                msg = types.TransactionType()
                msg.inputs.extend([current_tx.inputs[res.details.request_index], ])
                if debug_processor is not None:
                    # If debug_processor function is provided,
                    # pass thru it the request and prepared response.
                    # This is useful for unit tests, see test_msg_signtx
                    msg = debug_processor(res, msg)

                res = self.call(proto.TxAck(tx=msg))

            elif res.request_type == types.TXOUTPUT:
                msg = types.TransactionType()
                if res.details.tx_hash:
                    msg.bin_outputs.extend([current_tx.bin_outputs[res.details.request_index], ])
                else:
                    msg.outputs.extend([current_tx.outputs[res.details.request_index], ])

                if debug_processor is not None:
                    # If debug_processor function is provided,
                    # pass thru it the request and prepared response.
                    # This is useful for unit tests, see test_msg_signtx
                    msg = debug_processor(res, msg)

                res = self.call(proto.TxAck(tx=msg))

            elif res.request_type == types.TXEXTRADATA:
                o, l = res.details.extra_data_offset, res.details.extra_data_len
                msg = types.TransactionType()
                msg.extra_data = current_tx.extra_data[o:o + l]
                res = self.call(proto.TxAck(tx=msg))

            else:
                # Nothing was sent, so `res` would never change and the loop
                # would spin forever on the same request.
                raise CallException(types.Failure_UnexpectedMessage,
                                    "Unknown request type %s" % res.request_type)

        if None in signatures:
            raise Exception("Some signatures are missing!")

        log("SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" %
            (time.time() - start, counter, len(serialized_tx)))

        return (signatures, serialized_tx)

    def wipe_device(self):
        """Send WipeDevice, reload Features and return the response.

        The reload matters because the reset/recovery/load methods refuse
        to run while the cached features.initialized is still true.
        """
        ret = self.call(proto.WipeDevice())
        self.init_device()  # Reload Features
        return ret

    def recovery_device(self, word_count, passphrase_protection, pin_protection, label, language, type=types.RecoveryDeviceType_ScrambledWords, expand=False, dry_run=False):
        """Start the recovery workflow and return the device's final response.

        Also sets the attributes read by the UI callbacks while the words
        are entered (recovery_matrix_first_pass, expand, mnemonic_wordlist).

        Raises Exception when the device is already initialized (unless
        `dry_run`) or `word_count` is not 12, 18 or 24.
        """
        if self.features.initialized and not dry_run:
            raise Exception("Device is initialized already. Call wipe_device() and try again.")

        if word_count not in (12, 18, 24):
            raise Exception("Invalid word count. Use 12/18/24")

        self.recovery_matrix_first_pass = True

        self.expand = expand
        if self.expand:
            # optimization to load the wordlist once, instead of for each recovery word
            self.mnemonic_wordlist = Mnemonic('english')

        # NOTE(review): the call's arguments were missing and are rebuilt
        # from this method's parameters; enforce_wordlist=True is an
        # assumption - confirm against the RecoveryDevice protobuf definition.
        res = self.call(proto.RecoveryDevice(
            word_count=int(word_count),
            passphrase_protection=bool(passphrase_protection),
            pin_protection=bool(pin_protection),
            label=label,
            language=language,
            enforce_wordlist=True,
            type=type,
            dry_run=dry_run))

        self.init_device()  # Reload Features
        return res

    def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language, u2f_counter=0, skip_backup=False):
        """Run the device reset workflow and return the final response.

        The device answers ResetDevice with an EntropyRequest, which is
        acknowledged with 32 bytes of locally generated entropy.

        Raises Exception when the device is already initialized or does not
        answer with an EntropyRequest.
        """
        if self.features.initialized:
            raise Exception("Device is initialized already. Call wipe_device() and try again.")

        # Begin with device reset workflow
        # NOTE(review): the arguments after display_random were missing and
        # are rebuilt from this method's parameters - confirm the field
        # names against the ResetDevice protobuf definition.
        msg = proto.ResetDevice(display_random=display_random,
                                strength=strength,
                                passphrase_protection=bool(passphrase_protection),
                                pin_protection=bool(pin_protection),
                                language=language,
                                label=label,
                                u2f_counter=u2f_counter,
                                skip_backup=bool(skip_backup))

        resp = self.call(msg)
        if not isinstance(resp, proto.EntropyRequest):
            raise Exception("Invalid response, expected EntropyRequest")

        external_entropy = self._get_local_entropy()
        log("Computer generated entropy: " + binascii.hexlify(external_entropy).decode('ascii'))
        ret = self.call(proto.EntropyAck(entropy=external_entropy))
        self.init_device()  # Reload Features
        return ret

    def backup_device(self):
        """Send BackupDevice and return the device's response."""
        request = proto.BackupDevice()
        return self.call(request)

    def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language='english', skip_checksum=False, expand=False):
        """Load the device from a mnemonic and return the response.

        The mnemonic is normalized, optionally expanded with the English
        wordlist and, unless `skip_checksum`, checked before it is sent.

        Raises Exception for an invalid checksum or an already initialized
        device.
        """
        # Normalize with Mnemonic.normalize_string first ...
        mnemonic = Mnemonic.normalize_string(mnemonic)

        # ... then bring the result to unicode NFC form
        mnemonic = normalize_nfc(mnemonic)

        m = Mnemonic('english')

        if expand:
            mnemonic = m.expand(mnemonic)

        if not skip_checksum and not m.check(mnemonic):
            raise Exception("Invalid mnemonic checksum")

        if self.features.initialized:
            raise Exception("Device is initialized already. Call wipe_device() and try again.")

        # NOTE(review): the arguments after `pin` were missing and are
        # rebuilt from this method's parameters - confirm the field names
        # against the LoadDevice protobuf definition.
        resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin,
                                          passphrase_protection=passphrase_protection,
                                          language=language,
                                          label=label,
                                          skip_checksum=skip_checksum))
        self.init_device()  # Reload Features
        return resp

    def load_device_by_xprv(self, xprv, pin, passphrase_protection, label, language):
        """Load the device from a serialized extended private key.

        The base58 string is decoded, its checksum verified and the node
        fields extracted from fixed offsets of the hex dump.

        Raises Exception for an initialized device, an unknown prefix, a
        length outside 100..112 characters, a missing private-key marker
        or a checksum mismatch.
        """
        if self.features.initialized:
            raise Exception("Device is initialized already. Call wipe_device() and try again.")

        if xprv[0:4] not in ('xprv', 'tprv'):
            raise Exception("Unknown type of xprv")

        # `or`, not `and`: a length cannot be below 100 and above 112 at once.
        if len(xprv) < 100 or len(xprv) > 112:
            raise Exception("Invalid length of xprv")

        node = types.HDNodeType()
        data = binascii.hexlify(tools.b58decode(xprv, None))

        if data[90:92] != b'00':
            raise Exception("Contain invalid private key")

        checksum = binascii.hexlify(hashlib.sha256(hashlib.sha256(binascii.unhexlify(data[:156])).digest()).digest()[:4])
        if checksum != data[156:]:
            raise Exception("Checksum doesn't match")

        # version 0488ade4
        # depth 00
        # fingerprint 00000000
        # child_num 00000000
        # chaincode 873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508
        # privkey   00e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35
        # checksum e77e9d71

        node.depth = int(data[8:10], 16)
        node.fingerprint = int(data[10:18], 16)
        node.child_num = int(data[18:26], 16)
        node.chain_code = binascii.unhexlify(data[26:90])
        node.private_key = binascii.unhexlify(data[92:156])  # skip 0x00 indicating privkey

        # NOTE(review): the arguments after `node` were missing and are
        # rebuilt from this method's parameters - confirm the field names
        # against the LoadDevice protobuf definition.
        resp = self.call(proto.LoadDevice(node=node,
                                          pin=pin,
                                          passphrase_protection=passphrase_protection,
                                          language=language,
                                          label=label))
        self.init_device()  # Reload Features
        return resp

    def firmware_update(self, fp):
        """Upload the firmware read from file object `fp`.

        After FirmwareErase the device answers either with Success (the
        whole image is then sent in one FirmwareUpload) or with a
        FirmwareRequest (the image is then sent in the requested pieces,
        each accompanied by its blake2s digest).

        Returns True on success and False when the device reports
        Failure_FirmwareError. Raises Exception when the device is not in
        bootloader mode or answers unexpectedly.
        """
        if self.features.bootloader_mode is False:
            raise Exception("Device must be in bootloader mode")

        data = fp.read()

        resp = self.call(proto.FirmwareErase(length=len(data)))
        if isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError:
            return False

        # TREZORv1 method
        if isinstance(resp, proto.Success):
            fingerprint = hashlib.sha256(data[256:]).hexdigest()
            log("Firmware fingerprint: " + fingerprint)
            resp = self.call(proto.FirmwareUpload(payload=data))
            if isinstance(resp, proto.Success):
                return True
            elif isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError:
                return False
            raise Exception("Unexpected result %s" % resp)

        # TREZORv2 method
        if isinstance(resp, proto.FirmwareRequest):
            import pyblake2
            while True:
                payload = data[resp.offset:resp.offset + resp.length]
                digest = pyblake2.blake2s(payload).digest()
                resp = self.call(proto.FirmwareUpload(payload=payload, hash=digest))
                if isinstance(resp, proto.FirmwareRequest):
                    # The device asks for the next piece.
                    continue
                elif isinstance(resp, proto.Success):
                    return True
                elif isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError:
                    return False
                raise Exception("Unexpected result %s" % resp)

        raise Exception("Unexpected message %s" % resp)

    def self_test(self):
        """Send SelfTest with a fixed payload and return the response.

        Raises Exception when the device is not in bootloader mode.
        """
        in_bootloader = self.features.bootloader_mode
        if in_bootloader is False:
            raise Exception("Device must be in bootloader mode")

        request = proto.SelfTest(payload=b'\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC')
        return self.call(request)

class TrezorClient(ProtocolMixin, TextUIMixin, BaseClient):
    """Client combining the protocol methods with the text UI callbacks."""

class TrezorClientVerbose(ProtocolMixin, TextUIMixin, VerboseWireMixin, BaseClient):
    """TrezorClient variant that also logs every message sent and received."""

class TrezorClientDebugLink(ProtocolMixin, DebugLinkMixin, VerboseWireMixin, BaseClient):
    """Client for unit tests: DebugLink automatic callbacks plus wire logging."""