mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-03 03:11:17 +00:00
Heavily refactored TrezorClient
This commit is contained in:
parent
36b0c8095d
commit
558f61e635
@ -1,32 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import binascii
|
import binascii
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
import tools
|
import tools
|
||||||
import messages_pb2 as proto
|
import messages_pb2 as proto
|
||||||
import types_pb2 as types
|
import types_pb2 as types
|
||||||
from api_blockchain import BlockchainApi
|
from trezorlib.debuglink import DebugLink
|
||||||
|
|
||||||
# monkeypatching: text formatting of protobuf messages
|
# monkeypatching: text formatting of protobuf messages
|
||||||
tools.monkeypatch_google_protobuf_text_format()
|
tools.monkeypatch_google_protobuf_text_format()
|
||||||
|
|
||||||
def show_message(message):
|
def get_buttonrequest_value(code):
|
||||||
print "MESSAGE FROM DEVICE:", message
|
# Converts integer code to its string representation of ButtonRequestType
|
||||||
|
return [ k for k, v in types.ButtonRequestType.items() if v == code][0]
|
||||||
def show_input(input_text, message=None):
|
|
||||||
if message:
|
|
||||||
print "QUESTION FROM DEVICE:", message
|
|
||||||
return raw_input(input_text)
|
|
||||||
|
|
||||||
def pin_func(input_text, message=None):
|
|
||||||
return show_input(input_text, message)
|
|
||||||
|
|
||||||
def passphrase_func(input_text):
|
|
||||||
return show_input(input_text)
|
|
||||||
|
|
||||||
def word_func():
|
|
||||||
return raw_input("Enter one word of mnemonic: ")
|
|
||||||
|
|
||||||
class CallException(Exception):
|
class CallException(Exception):
|
||||||
def __init__(self, code, message):
|
def __init__(self, code, message):
|
||||||
@ -36,39 +22,189 @@ class CallException(Exception):
|
|||||||
class PinException(CallException):
|
class PinException(CallException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
PRIME_DERIVATION_FLAG = 0x80000000
|
class field(object):
|
||||||
|
# Decorator extracts single value from
|
||||||
|
# protobuf object. If the field is not
|
||||||
|
# present, raises an exception.
|
||||||
|
def __init__(self, field):
|
||||||
|
self.field = field
|
||||||
|
|
||||||
class TrezorClient(object):
|
def __call__(self, f):
|
||||||
|
def wrapped_f(*args, **kwargs):
|
||||||
|
ret = f(*args, **kwargs)
|
||||||
|
ret.HasField(self.field)
|
||||||
|
return getattr(ret, self.field)
|
||||||
|
return wrapped_f
|
||||||
|
|
||||||
def __init__(self, transport, debuglink=None,
|
class expect(object):
|
||||||
message_func=show_message, input_func=show_input,
|
# Decorator checks if the method
|
||||||
pin_func=pin_func, passphrase_func=passphrase_func,
|
# returned one of expected protobuf messages
|
||||||
word_func=word_func, blockchain_api=None, debug=False):
|
# or raises an exception
|
||||||
|
def __init__(self, *expected):
|
||||||
|
self.expected = expected
|
||||||
|
|
||||||
|
def __call__(self, f):
|
||||||
|
def wrapped_f(*args, **kwargs):
|
||||||
|
ret = f(*args, **kwargs)
|
||||||
|
if not isinstance(ret, self.expected):
|
||||||
|
raise Exception("Got %s, expected %s" % (ret.__class__, self.expected))
|
||||||
|
return ret
|
||||||
|
return wrapped_f
|
||||||
|
|
||||||
|
class BaseClient(object):
|
||||||
|
# Implements very basic layer of sending raw protobuf
|
||||||
|
# messages to device and getting its response back.
|
||||||
|
def __init__(self, transport, *args, **kwargs):
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.debuglink = debuglink
|
super(BaseClient, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.message_func = message_func
|
def call(self, msg):
|
||||||
self.input_func = input_func
|
try:
|
||||||
self.pin_func = pin_func
|
self.transport.session_begin()
|
||||||
self.passphrase_func = passphrase_func
|
|
||||||
self.word_func = word_func
|
|
||||||
|
|
||||||
self.debug = debug
|
self.transport.write(msg)
|
||||||
|
resp = self.transport.read_blocking()
|
||||||
|
|
||||||
if blockchain_api:
|
handler_name = "callback_%s" % resp.__class__.__name__
|
||||||
self.blockchain = blockchain_api
|
handler = getattr(self, handler_name, None)
|
||||||
|
|
||||||
|
if handler != None:
|
||||||
|
msg = handler(resp)
|
||||||
|
if msg == None:
|
||||||
|
raise Exception("Callback %s must return protobug message, not None" % handler)
|
||||||
|
|
||||||
|
resp = self.call(msg)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
self.transport.session_end()
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def callback_Failure(self, msg):
|
||||||
|
if msg.code in (types.Failure_PinInvalid,
|
||||||
|
types.Failure_PinCancelled, types.Failure_PinExpected):
|
||||||
|
raise PinException(msg.code, msg.message)
|
||||||
|
|
||||||
|
raise CallException(msg.code, msg.message)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
class TextUIMixin(object):
|
||||||
|
# This class demonstrates easy test-based UI
|
||||||
|
# integration between the device and wallet.
|
||||||
|
# You can implement similar functionality
|
||||||
|
# by implementing your own GuiMixin with
|
||||||
|
# graphical widgets for every type of these callbacks.
|
||||||
|
|
||||||
|
def callback_ButtonRequest(self, msg):
|
||||||
|
print "Sending ButtonAck for %s " % get_buttonrequest_value(msg.code)
|
||||||
|
return proto.ButtonAck()
|
||||||
|
|
||||||
|
def callback_PinMatrixRequest(self, msg):
|
||||||
|
pin = raw_input("PIN required: %s " % msg.message)
|
||||||
|
return proto.PinMatrixAck(pin=pin)
|
||||||
|
|
||||||
|
def callback_PassphraseRequest(self, msg):
|
||||||
|
passphrase = raw_input("Passphrase required: %s " % msg.message)
|
||||||
|
return proto.PassphraseAck(passphrase=passphrase)
|
||||||
|
|
||||||
|
def callback_WordRequest(self, msg):
|
||||||
|
word = raw_input("Enter one word of mnemonic: ")
|
||||||
|
return proto.WordAck(word=word)
|
||||||
|
|
||||||
|
class DebugLinkMixin(object):
|
||||||
|
# This class implements automatic responses
|
||||||
|
# and other functionality for unit tests
|
||||||
|
# for various callbacks, created in order
|
||||||
|
# to automatically pass unit tests.
|
||||||
|
#
|
||||||
|
# This mixing should be used only for purposes
|
||||||
|
# of unit testing, because it will fail to work
|
||||||
|
# without special DebugLink interface provided
|
||||||
|
# by the device.
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(DebugLinkMixin, self).__init__(*args, **kwargs)
|
||||||
|
self.debug = None
|
||||||
|
|
||||||
|
# Always press Yes and provide correct pin
|
||||||
|
self.setup_debuglink(True, True)
|
||||||
|
|
||||||
|
# Do not expect any specific ButtonRequest
|
||||||
|
self.set_expected_buttonrequests(None)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
super(DebugLinkMixin, self).close()
|
||||||
|
if self.debug:
|
||||||
|
self.debug.close()
|
||||||
|
|
||||||
|
def set_debuglink(self, debug_transport):
|
||||||
|
self.debug = DebugLink(debug_transport)
|
||||||
|
|
||||||
|
def set_expected_buttonrequests(self, expected):
|
||||||
|
self.expected_buttonrequests = expected
|
||||||
|
|
||||||
|
def setup_debuglink(self, button, pin_correct):
|
||||||
|
self.button = button # True -> YES button, False -> NO button
|
||||||
|
self.pin_correct = pin_correct
|
||||||
|
|
||||||
|
def callback_ButtonRequest(self, msg):
|
||||||
|
if self.expected_buttonrequests != None:
|
||||||
|
try:
|
||||||
|
expected = self.expected_buttonrequests.pop(0)
|
||||||
|
if msg.code != expected:
|
||||||
|
raise CallException(types.Failure_Other,
|
||||||
|
"Expected %s, got %s" % \
|
||||||
|
(get_buttonrequest_value(expected),
|
||||||
|
get_buttonrequest_value(msg.code)))
|
||||||
|
except IndexError:
|
||||||
|
raise CallException(types.Failure_Other,
|
||||||
|
"Got %s, but no ButtonRequest has been expected" % \
|
||||||
|
get_buttonrequest_value(msg.code))
|
||||||
|
|
||||||
|
print "ButtonRequest code:", get_buttonrequest_value(msg.code)
|
||||||
|
|
||||||
|
print "Pressing button", self.button
|
||||||
|
self.debug.press_button(self.button)
|
||||||
|
return proto.ButtonAck()
|
||||||
|
|
||||||
|
def callback_PinMatrixRequest(self, msg):
|
||||||
|
if self.pin_correct:
|
||||||
|
pin = self.debug.read_pin_encoded()
|
||||||
else:
|
else:
|
||||||
self.blockchain = BlockchainApi()
|
pin = '444222'
|
||||||
|
return proto.PinMatrixAck(pin=pin)
|
||||||
|
|
||||||
self.setup_debuglink()
|
def callback_PassphraseRequest(self, msg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def callback_WordRequest(self, msg):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ProtocolMixin(object):
|
||||||
|
PRIME_DERIVATION_FLAG = 0x80000000
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(ProtocolMixin, self).__init__() # *args, **kwargs)
|
||||||
self.init_device()
|
self.init_device()
|
||||||
|
|
||||||
|
def get_tx_func_placeholder(txhash):
|
||||||
|
raise Exception("Please call set_tx_func() first.")
|
||||||
|
self.get_tx_func = get_tx_func_placeholder
|
||||||
|
|
||||||
|
def set_tx_func(self, tx_func):
|
||||||
|
self.get_tx_func = tx_func
|
||||||
|
|
||||||
|
def init_device(self):
|
||||||
|
self.features = expect(proto.Features)(self.call)(proto.Initialize())
|
||||||
|
|
||||||
def _get_local_entropy(self):
|
def _get_local_entropy(self):
|
||||||
return os.urandom(32)
|
return os.urandom(32)
|
||||||
|
|
||||||
def _convert_prime(self, n):
|
def _convert_prime(self, n):
|
||||||
# Convert minus signs to uint32 with flag
|
# Convert minus signs to uint32 with flag
|
||||||
return [ int(abs(x) | PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ]
|
return [ int(abs(x) | self.PRIME_DERIVATION_FLAG) if x < 0 else x for x in n ]
|
||||||
|
|
||||||
def expand_path(self, n):
|
def expand_path(self, n):
|
||||||
# Convert string of bip32 path to list of uint32 integers with prime flags
|
# Convert string of bip32 path to list of uint32 integers with prime flags
|
||||||
@ -89,40 +225,42 @@ class TrezorClient(object):
|
|||||||
x = abs(int(x))
|
x = abs(int(x))
|
||||||
|
|
||||||
if prime:
|
if prime:
|
||||||
x |= PRIME_DERIVATION_FLAG
|
x |= self.PRIME_DERIVATION_FLAG
|
||||||
|
|
||||||
path.append(x)
|
path.append(x)
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def init_device(self):
|
@field('node')
|
||||||
self.features = self.call(proto.Initialize(), proto.Features)
|
@expect(proto.PublicKey)
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.transport.close()
|
|
||||||
if self.debuglink:
|
|
||||||
self.debuglink.transport.close()
|
|
||||||
|
|
||||||
def get_public_node(self, n):
|
def get_public_node(self, n):
|
||||||
return self.call(proto.GetPublicKey(address_n=n), proto.PublicKey).node
|
return self.call(proto.GetPublicKey(address_n=n))
|
||||||
|
|
||||||
|
@field('address')
|
||||||
|
@expect(proto.Address)
|
||||||
def get_address(self, coin_name, n):
|
def get_address(self, coin_name, n):
|
||||||
n = self._convert_prime(n)
|
n = self._convert_prime(n)
|
||||||
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name), proto.Address).address
|
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name))
|
||||||
|
|
||||||
|
@field('entropy')
|
||||||
|
@expect(proto.Entropy)
|
||||||
def get_entropy(self, size):
|
def get_entropy(self, size):
|
||||||
return self.call(proto.GetEntropy(size=size), proto.Entropy).entropy
|
return self.call(proto.GetEntropy(size=size))
|
||||||
|
|
||||||
|
@field('message')
|
||||||
|
@expect(proto.Success)
|
||||||
def ping(self, msg, button_protection=False, pin_protection=False, passphrase_protection=False):
|
def ping(self, msg, button_protection=False, pin_protection=False, passphrase_protection=False):
|
||||||
msg = proto.Ping(message=msg,
|
msg = proto.Ping(message=msg,
|
||||||
button_protection=button_protection,
|
button_protection=button_protection,
|
||||||
pin_protection=pin_protection,
|
pin_protection=pin_protection,
|
||||||
passphrase_protection=passphrase_protection)
|
passphrase_protection=passphrase_protection)
|
||||||
return self.call(msg, proto.Success).message
|
return self.call(msg)
|
||||||
|
|
||||||
def get_device_id(self):
|
def get_device_id(self):
|
||||||
return self.features.device_id
|
return self.features.device_id
|
||||||
|
|
||||||
|
@field('message')
|
||||||
|
@expect(proto.Success)
|
||||||
def apply_settings(self, label=None, language=None):
|
def apply_settings(self, label=None, language=None):
|
||||||
settings = proto.ApplySettings()
|
settings = proto.ApplySettings()
|
||||||
if label != None:
|
if label != None:
|
||||||
@ -130,27 +268,203 @@ class TrezorClient(object):
|
|||||||
if language:
|
if language:
|
||||||
settings.language = language
|
settings.language = language
|
||||||
|
|
||||||
out = self.call(settings, proto.Success).message
|
out = self.call(settings)
|
||||||
self.init_device() # Reload Features
|
self.init_device() # Reload Features
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@field('message')
|
||||||
|
@expect(proto.Success)
|
||||||
def change_pin(self, remove=False):
|
def change_pin(self, remove=False):
|
||||||
ret = self.call(proto.ChangePin(remove=remove))
|
ret = self.call(proto.ChangePin(remove=remove))
|
||||||
self.init_device() # Re-read features
|
self.init_device() # Re-read features
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@expect(proto.MessageSignature)
|
||||||
|
def sign_message(self, n, message):
|
||||||
|
n = self._convert_prime(n)
|
||||||
|
return self.call(proto.SignMessage(address_n=n, message=message))
|
||||||
|
|
||||||
|
def verify_message(self, address, signature, message):
|
||||||
|
try:
|
||||||
|
resp = self.call(proto.VerifyMessage(address=address, signature=signature, message=message))
|
||||||
|
except CallException as e:
|
||||||
|
resp = e
|
||||||
|
if isinstance(resp, proto.Success):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@field('tx_size')
|
||||||
|
@expect(proto.TxSize)
|
||||||
|
def estimate_tx_size(self, coin_name, inputs, outputs):
|
||||||
|
msg = proto.EstimateTxSize()
|
||||||
|
msg.coin_name = coin_name
|
||||||
|
msg.inputs_count = len(inputs)
|
||||||
|
msg.outputs_count = len(outputs)
|
||||||
|
return self.call(msg)
|
||||||
|
|
||||||
|
def _prepare_simple_sign_tx(self, coin_name, inputs, outputs):
|
||||||
|
msg = proto.SimpleSignTx()
|
||||||
|
msg.coin_name = coin_name
|
||||||
|
msg.inputs.extend(inputs)
|
||||||
|
msg.outputs.extend(outputs)
|
||||||
|
|
||||||
|
known_hashes = []
|
||||||
|
for inp in inputs:
|
||||||
|
if inp.prev_hash in known_hashes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tx = msg.transactions.add()
|
||||||
|
tx.CopyFrom(self.get_tx_func(binascii.hexlify(inp.prev_hash)))
|
||||||
|
known_hashes.append(inp.prev_hash)
|
||||||
|
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@field('serialized_tx')
|
||||||
|
@expect(proto.TxRequest)
|
||||||
|
def simple_sign_tx(self, coin_name, inputs, outputs):
|
||||||
|
# TODO Deserialize tx and check if inputs/outputs fits
|
||||||
|
msg = self._prepare_simple_sign_tx(coin_name, inputs, outputs)
|
||||||
|
return self.call(msg)
|
||||||
|
|
||||||
|
def sign_tx(self, coin_name, inputs, outputs):
|
||||||
|
# Temporary solution, until streaming is implemented in the firmware
|
||||||
|
return self.simple_sign_tx(coin_name, inputs, outputs)
|
||||||
|
|
||||||
|
@field('message')
|
||||||
|
@expect(proto.Success)
|
||||||
|
def wipe_device(self):
|
||||||
|
ret = self.call(proto.WipeDevice())
|
||||||
|
self.init_device()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@field('message')
|
||||||
|
@expect(proto.Success)
|
||||||
|
def recovery_device(self, word_count, passphrase_protection, pin_protection, label, language):
|
||||||
|
if self.features.initialized:
|
||||||
|
raise Exception("Device is initialized already. Call wipe_device() and try again.")
|
||||||
|
|
||||||
|
if word_count not in (12, 18, 24):
|
||||||
|
raise Exception("Invalid word count. Use 12/18/24")
|
||||||
|
|
||||||
|
res = self.call(proto.RecoveryDevice(word_count=int(word_count),
|
||||||
|
passphrase_protection=bool(passphrase_protection),
|
||||||
|
pin_protection=bool(pin_protection),
|
||||||
|
label=label,
|
||||||
|
language=language,
|
||||||
|
enforce_wordlist=True))
|
||||||
|
|
||||||
|
self.init_device()
|
||||||
|
return res
|
||||||
|
|
||||||
|
@field('message')
|
||||||
|
@expect(proto.Success)
|
||||||
|
def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language):
|
||||||
|
if self.features.initialized:
|
||||||
|
raise Exception("Device is initialized already. Call wipe_device() and try again.")
|
||||||
|
|
||||||
|
# Begin with device reset workflow
|
||||||
|
msg = proto.ResetDevice(display_random=display_random,
|
||||||
|
strength=strength,
|
||||||
|
language=language,
|
||||||
|
passphrase_protection=bool(passphrase_protection),
|
||||||
|
pin_protection=bool(pin_protection),
|
||||||
|
label=label)
|
||||||
|
|
||||||
|
resp = self.call(msg)
|
||||||
|
if not isinstance(resp, proto.EntropyRequest):
|
||||||
|
raise Exception("Invalid response, expected EntropyRequest")
|
||||||
|
|
||||||
|
external_entropy = self._get_local_entropy()
|
||||||
|
print "Computer generated entropy:", binascii.hexlify(external_entropy)
|
||||||
|
return self.call(proto.EntropyAck(entropy=external_entropy))
|
||||||
|
|
||||||
|
@field('message')
|
||||||
|
@expect(proto.Success)
|
||||||
|
def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language):
|
||||||
|
if self.features.initialized:
|
||||||
|
raise Exception("Device is initialized already. Call wipe_device() and try again.")
|
||||||
|
|
||||||
|
resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin,
|
||||||
|
passphrase_protection=passphrase_protection,
|
||||||
|
language=language,
|
||||||
|
label=label))
|
||||||
|
self.init_device()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@field('message')
|
||||||
|
@expect(proto.Success)
|
||||||
|
def load_device_by_xprv(self, xprv, pin, passphrase_protection, label):
|
||||||
|
if self.features.initialized:
|
||||||
|
raise Exception("Device is initialized already. Call wipe_device() and try again.")
|
||||||
|
|
||||||
|
if xprv[0:4] not in ('xprv', 'tprv'):
|
||||||
|
raise Exception("Unknown type of xprv")
|
||||||
|
|
||||||
|
if len(xprv) < 100 and len(xprv) > 112:
|
||||||
|
raise Exception("Invalid length of xprv")
|
||||||
|
|
||||||
|
node = types.HDNodeType()
|
||||||
|
data = tools.b58decode(xprv, None).encode('hex')
|
||||||
|
|
||||||
|
if data[90:92] != '00':
|
||||||
|
raise Exception("Contain invalid private key")
|
||||||
|
|
||||||
|
checksum = hashlib.sha256(hashlib.sha256(binascii.unhexlify(data[:156])).digest()).hexdigest()[:8]
|
||||||
|
if checksum != data[156:]:
|
||||||
|
raise Exception("Checksum doesn't match")
|
||||||
|
|
||||||
|
# version 0488ade4
|
||||||
|
# depth 00
|
||||||
|
# fingerprint 00000000
|
||||||
|
# child_num 00000000
|
||||||
|
# chaincode 873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508
|
||||||
|
# privkey 00e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35
|
||||||
|
# checksum e77e9d71
|
||||||
|
|
||||||
|
node.version = int(data[0:8], 16)
|
||||||
|
node.depth = int(data[8:10], 16)
|
||||||
|
node.fingerprint = int(data[10:18], 16)
|
||||||
|
node.child_num = int(data[18:26], 16)
|
||||||
|
node.chain_code = data[26:90].decode('hex')
|
||||||
|
node.private_key = data[92:156].decode('hex') # skip 0x00 indicating privkey
|
||||||
|
|
||||||
|
resp = self.call(proto.LoadDevice(node=node,
|
||||||
|
pin=pin,
|
||||||
|
passphrase_protection=passphrase_protection,
|
||||||
|
language='english',
|
||||||
|
label=label))
|
||||||
|
self.init_device()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def firmware_update(self, fp):
|
||||||
|
if self.features.bootloader_mode == False:
|
||||||
|
raise Exception("Device must be in bootloader mode")
|
||||||
|
|
||||||
|
resp = self.call(proto.FirmwareErase())
|
||||||
|
if isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
resp = self.call(proto.FirmwareUpload(payload=fp.read()))
|
||||||
|
if isinstance(resp, proto.Success):
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
raise Exception("Unexpected result " % resp)
|
||||||
|
|
||||||
|
class TrezorClient(BaseClient, ProtocolMixin, TextUIMixin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TrezorDebugClient(BaseClient, ProtocolMixin, DebugLinkMixin):
|
||||||
|
pass
|
||||||
|
|
||||||
|
'''
|
||||||
|
class TrezorClient(object):
|
||||||
def _pprint(self, msg):
|
def _pprint(self, msg):
|
||||||
ser = msg.SerializeToString()
|
ser = msg.SerializeToString()
|
||||||
return "<%s> (%d bytes):\n%s" % (msg.__class__.__name__, len(ser), msg)
|
return "<%s> (%d bytes):\n%s" % (msg.__class__.__name__, len(ser), msg)
|
||||||
|
|
||||||
def setup_debuglink(self, button=None, pin_correct=False):
|
|
||||||
self.debug_button = button
|
|
||||||
self.debug_pin = pin_correct
|
|
||||||
|
|
||||||
def _get_buttonrequest_value(self, code):
|
|
||||||
return [ k for k, v in types.ButtonRequestType.items() if v == code][0]
|
|
||||||
|
|
||||||
def call(self, msg, expected=None, expected_buttonrequests=None):
|
def call(self, msg, expected=None, expected_buttonrequests=None):
|
||||||
# TODO split this into normal and debug mode
|
# TODO split this into normal and debug mode
|
||||||
if self.debug:
|
if self.debug:
|
||||||
@ -201,7 +515,7 @@ class TrezorClient(object):
|
|||||||
|
|
||||||
if isinstance(resp, proto.PassphraseRequest):
|
if isinstance(resp, proto.PassphraseRequest):
|
||||||
passphrase = self.passphrase_func("Passphrase required: ")
|
passphrase = self.passphrase_func("Passphrase required: ")
|
||||||
msg2 = proto.PassphraseAck(passphrase=passphrase)
|
ms(object)g2 = proto.PassphraseAck(passphrase=passphrase)
|
||||||
return self.call(msg2, expected=expected, expected_buttonrequests=expected_buttonrequests)
|
return self.call(msg2, expected=expected, expected_buttonrequests=expected_buttonrequests)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
@ -229,55 +543,8 @@ class TrezorClient(object):
|
|||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
def sign_message(self, n, message):
|
|
||||||
n = self._convert_prime(n)
|
|
||||||
return self.call(proto.SignMessage(address_n=n, message=message))
|
|
||||||
|
|
||||||
def verify_message(self, address, signature, message):
|
|
||||||
try:
|
|
||||||
resp = self.call(proto.VerifyMessage(address=address, signature=signature, message=message))
|
|
||||||
if isinstance(resp, proto.Success):
|
|
||||||
return True
|
|
||||||
except CallException:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def estimate_tx_size(self, coin_name, inputs, outputs):
|
|
||||||
msg = proto.EstimateTxSize()
|
|
||||||
msg.coin_name = coin_name
|
|
||||||
msg.inputs_count = len(inputs)
|
|
||||||
msg.outputs_count = len(outputs)
|
|
||||||
res = self.call(msg)
|
|
||||||
return res.tx_size
|
|
||||||
|
|
||||||
def _prepare_simple_sign_tx(self, coin_name, inputs, outputs):
|
|
||||||
msg = proto.SimpleSignTx()
|
|
||||||
msg.coin_name = coin_name
|
|
||||||
msg.inputs.extend(inputs)
|
|
||||||
msg.outputs.extend(outputs)
|
|
||||||
|
|
||||||
known_hashes = []
|
|
||||||
for inp in inputs:
|
|
||||||
if inp.prev_hash in known_hashes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
tx = msg.transactions.add()
|
|
||||||
tx.CopyFrom(self.blockchain.get_tx(binascii.hexlify(inp.prev_hash)))
|
|
||||||
known_hashes.append(inp.prev_hash)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
|
|
||||||
def simple_sign_tx(self, coin_name, inputs, outputs):
|
|
||||||
msg = self._prepare_simple_sign_tx(coin_name, inputs, outputs)
|
|
||||||
return self.call(msg)
|
|
||||||
|
|
||||||
def sign_tx(self, coin_name, inputs, outputs):
|
|
||||||
# Temporary solution, until streaming is implemented in the firmware
|
|
||||||
return self.simple_sign_tx(coin_name, inputs, outputs)
|
|
||||||
|
|
||||||
def _sign_tx(self, coin_name, inputs, outputs):
|
def _sign_tx(self, coin_name, inputs, outputs):
|
||||||
'''
|
''
|
||||||
inputs: list of TxInput
|
inputs: list of TxInput
|
||||||
outputs: list of TxOutput
|
outputs: list of TxOutput
|
||||||
|
|
||||||
@ -295,7 +562,7 @@ class TrezorClient(object):
|
|||||||
script_type=proto.PAYTOADDRESS,
|
script_type=proto.PAYTOADDRESS,
|
||||||
#script_args=
|
#script_args=
|
||||||
)
|
)
|
||||||
'''
|
''
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
@ -352,125 +619,4 @@ class TrezorClient(object):
|
|||||||
(time.time() - start, counter, len(serialized_tx))
|
(time.time() - start, counter, len(serialized_tx))
|
||||||
|
|
||||||
return (signatures, serialized_tx)
|
return (signatures, serialized_tx)
|
||||||
|
'''
|
||||||
def wipe_device(self):
|
|
||||||
ret = self.call(proto.WipeDevice())
|
|
||||||
self.init_device()
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def recovery_device(self, word_count, passphrase_protection, pin_protection, label, language):
|
|
||||||
if word_count not in (12, 18, 24):
|
|
||||||
raise Exception("Invalid word count. Use 12/18/24")
|
|
||||||
|
|
||||||
res = self.call(proto.RecoveryDevice(word_count=int(word_count),
|
|
||||||
passphrase_protection=bool(passphrase_protection),
|
|
||||||
pin_protection=bool(pin_protection),
|
|
||||||
label=label,
|
|
||||||
language=language,
|
|
||||||
enforce_wordlist=True))
|
|
||||||
|
|
||||||
while isinstance(res, proto.WordRequest):
|
|
||||||
word = self.word_func()
|
|
||||||
res = self.call(proto.WordAck(word=word))
|
|
||||||
|
|
||||||
if not isinstance(res, proto.Success):
|
|
||||||
raise Exception("Recovery device failed")
|
|
||||||
|
|
||||||
self.init_device()
|
|
||||||
return True
|
|
||||||
|
|
||||||
def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language):
|
|
||||||
if self.features.initialized:
|
|
||||||
raise Exception("Device is initialized already. Call wipe_device() and try again.")
|
|
||||||
|
|
||||||
# Begin with device reset workflow
|
|
||||||
msg = proto.ResetDevice(display_random=display_random,
|
|
||||||
strength=strength,
|
|
||||||
language=language,
|
|
||||||
passphrase_protection=bool(passphrase_protection),
|
|
||||||
pin_protection=bool(pin_protection),
|
|
||||||
label=label)
|
|
||||||
|
|
||||||
resp = self.call(msg)
|
|
||||||
if not isinstance(resp, proto.EntropyRequest):
|
|
||||||
raise Exception("Invalid response, expected EntropyRequest")
|
|
||||||
|
|
||||||
external_entropy = self._get_local_entropy()
|
|
||||||
print "Computer generated entropy:", binascii.hexlify(external_entropy)
|
|
||||||
resp = self.call(proto.EntropyAck(entropy=external_entropy))
|
|
||||||
|
|
||||||
|
|
||||||
return isinstance(resp, proto.Success)
|
|
||||||
|
|
||||||
def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language):
|
|
||||||
if self.features.initialized:
|
|
||||||
raise Exception("Device is initialized already. Call wipe_device() and try again.")
|
|
||||||
|
|
||||||
resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin,
|
|
||||||
passphrase_protection=passphrase_protection,
|
|
||||||
language=language,
|
|
||||||
label=label))
|
|
||||||
self.init_device()
|
|
||||||
return isinstance(resp, proto.Success)
|
|
||||||
|
|
||||||
def load_device_by_xprv(self, xprv, pin, passphrase_protection, label):
|
|
||||||
if self.features.initialized:
|
|
||||||
raise Exception("Device is initialized already. Call wipe_device() and try again.")
|
|
||||||
|
|
||||||
if xprv[0:4] not in ('xprv', 'tprv'):
|
|
||||||
raise Exception("Unknown type of xprv")
|
|
||||||
|
|
||||||
if len(xprv) < 100 and len(xprv) > 112:
|
|
||||||
raise Exception("Invalid length of xprv")
|
|
||||||
|
|
||||||
node = types.HDNodeType()
|
|
||||||
data = tools.b58decode(xprv, None).encode('hex')
|
|
||||||
|
|
||||||
if data[90:92] != '00':
|
|
||||||
raise Exception("Contain invalid private key")
|
|
||||||
|
|
||||||
checksum = hashlib.sha256(hashlib.sha256(binascii.unhexlify(data[:156])).digest()).hexdigest()[:8]
|
|
||||||
if checksum != data[156:]:
|
|
||||||
raise Exception("Checksum doesn't match")
|
|
||||||
|
|
||||||
# version 0488ade4
|
|
||||||
# depth 00
|
|
||||||
# fingerprint 00000000
|
|
||||||
# child_num 00000000
|
|
||||||
# chaincode 873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508
|
|
||||||
# privkey 00e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35
|
|
||||||
# checksum e77e9d71
|
|
||||||
|
|
||||||
node.version = int(data[0:8], 16)
|
|
||||||
node.depth = int(data[8:10], 16)
|
|
||||||
node.fingerprint = int(data[10:18], 16)
|
|
||||||
node.child_num = int(data[18:26], 16)
|
|
||||||
node.chain_code = data[26:90].decode('hex')
|
|
||||||
node.private_key = data[92:156].decode('hex') # skip 0x00 indicating privkey
|
|
||||||
|
|
||||||
resp = self.call(proto.LoadDevice(node=node,
|
|
||||||
pin=pin,
|
|
||||||
passphrase_protection=passphrase_protection,
|
|
||||||
language='english',
|
|
||||||
label=label))
|
|
||||||
self.init_device()
|
|
||||||
return isinstance(resp, proto.Success)
|
|
||||||
|
|
||||||
def firmware_update(self, fp):
|
|
||||||
if self.features.bootloader_mode == False:
|
|
||||||
raise Exception("Device must be in bootloader mode")
|
|
||||||
|
|
||||||
resp = self.call(proto.FirmwareErase())
|
|
||||||
if isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
resp = self.call(proto.FirmwareUpload(payload=fp.read()))
|
|
||||||
if isinstance(resp, proto.Success):
|
|
||||||
return True
|
|
||||||
|
|
||||||
elif isinstance(resp, proto.Failure) and resp.code == types.Failure_FirmwareError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
raise Exception("Unexpected result " % resp)
|
|
||||||
|
|
||||||
# class TrezorDebugClient(TrezorClient):
|
|
||||||
|
@ -14,6 +14,9 @@ class DebugLink(object):
|
|||||||
self.pin_func = pin_func
|
self.pin_func = pin_func
|
||||||
self.button_func = button_func
|
self.button_func = button_func
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
def read_pin(self):
|
def read_pin(self):
|
||||||
self.transport.write(proto.DebugLinkGetState())
|
self.transport.write(proto.DebugLinkGetState())
|
||||||
obj = self.transport.read_blocking()
|
obj = self.transport.read_blocking()
|
||||||
|
Loading…
Reference in New Issue
Block a user