mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-18 05:28:40 +00:00
54f1599a5a
This clarifies the intent: the project is licenced under terms of LGPL version 3 only, but the standard headers cover only "3 or later", so we had to rewrite them. In the same step, we removed author information from individual files in favor of "SatoshiLabs and contributors", and include an AUTHORS file that lists the contributors. Apologies to those whose names are missing; please contact us if you wish to add your info to the AUTHORS file.
1150 lines
43 KiB
Python
1150 lines
43 KiB
Python
# This file is part of the Trezor project.
|
|
#
|
|
# Copyright (C) 2012-2018 SatoshiLabs and contributors
|
|
#
|
|
# This library is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Lesser General Public License version 3
|
|
# as published by the Free Software Foundation.
|
|
#
|
|
# This library is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Lesser General Public License for more details.
|
|
#
|
|
# You should have received a copy of the License along with this library.
|
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
|
|
|
import functools
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
import binascii
|
|
import hashlib
|
|
import unicodedata
|
|
import getpass
|
|
import warnings
|
|
|
|
from mnemonic import Mnemonic
|
|
|
|
from . import messages as proto
|
|
from . import tools
|
|
from . import mapping
|
|
from . import nem
|
|
from . import protobuf
|
|
from . import stellar
|
|
from .debuglink import DebugLink
|
|
|
|
if sys.version_info.major < 3:
|
|
raise Exception("Trezorlib does not support Python 2 anymore.")
|
|
|
|
|
|
SCREENSHOT = False
|
|
LOG = logging.getLogger(__name__)
|
|
|
|
# make a getch function
|
|
try:
|
|
import termios
|
|
import tty
|
|
# POSIX system. Create and return a getch that manipulates the tty.
|
|
# On Windows, termios will fail to import.
|
|
|
|
def getch():
|
|
fd = sys.stdin.fileno()
|
|
old_settings = termios.tcgetattr(fd)
|
|
try:
|
|
tty.setraw(fd)
|
|
ch = sys.stdin.read(1)
|
|
finally:
|
|
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
|
return ch
|
|
|
|
except ImportError:
|
|
# Windows system.
|
|
# Use msvcrt's getch function.
|
|
import msvcrt
|
|
|
|
def getch():
|
|
while True:
|
|
key = msvcrt.getch()
|
|
if key in (0x00, 0xe0):
|
|
# skip special keys: read the scancode and repeat
|
|
msvcrt.getch()
|
|
continue
|
|
return key.decode('latin1')
|
|
|
|
|
|
def get_buttonrequest_value(code):
|
|
# Converts integer code to its string representation of ButtonRequestType
|
|
return [k for k in dir(proto.ButtonRequestType) if getattr(proto.ButtonRequestType, k) == code][0]
|
|
|
|
|
|
class CallException(Exception):
|
|
pass
|
|
|
|
|
|
class PinException(CallException):
|
|
pass
|
|
|
|
|
|
class field:
|
|
# Decorator extracts single value from
|
|
# protobuf object. If the field is not
|
|
# present, raises an exception.
|
|
def __init__(self, field):
|
|
self.field = field
|
|
|
|
def __call__(self, f):
|
|
@functools.wraps(f)
|
|
def wrapped_f(*args, **kwargs):
|
|
ret = f(*args, **kwargs)
|
|
return getattr(ret, self.field)
|
|
return wrapped_f
|
|
|
|
|
|
class expect:
|
|
# Decorator checks if the method
|
|
# returned one of expected protobuf messages
|
|
# or raises an exception
|
|
def __init__(self, *expected):
|
|
self.expected = expected
|
|
|
|
def __call__(self, f):
|
|
@functools.wraps(f)
|
|
def wrapped_f(*args, **kwargs):
|
|
ret = f(*args, **kwargs)
|
|
if not isinstance(ret, self.expected):
|
|
raise RuntimeError("Got %s, expected %s" % (ret.__class__, self.expected))
|
|
return ret
|
|
return wrapped_f
|
|
|
|
|
|
def session(f):
|
|
# Decorator wraps a BaseClient method
|
|
# with session activation / deactivation
|
|
@functools.wraps(f)
|
|
def wrapped_f(*args, **kwargs):
|
|
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
|
|
client = args[0]
|
|
client.transport.session_begin()
|
|
try:
|
|
return f(*args, **kwargs)
|
|
finally:
|
|
client.transport.session_end()
|
|
return wrapped_f
|
|
|
|
|
|
def normalize_nfc(txt):
|
|
'''
|
|
Normalize message to NFC and return bytes suitable for protobuf.
|
|
This seems to be bitcoin-qt standard of doing things.
|
|
'''
|
|
if isinstance(txt, bytes):
|
|
txt = txt.decode('utf-8')
|
|
return unicodedata.normalize('NFC', txt).encode('utf-8')
|
|
|
|
|
|
class BaseClient(object):
|
|
# Implements very basic layer of sending raw protobuf
|
|
# messages to device and getting its response back.
|
|
def __init__(self, transport, **kwargs):
|
|
LOG.info("creating client instance for device: {}".format(transport.get_path()))
|
|
self.transport = transport
|
|
super(BaseClient, self).__init__() # *args, **kwargs)
|
|
|
|
def close(self):
|
|
pass
|
|
|
|
def cancel(self):
|
|
self.transport.write(proto.Cancel())
|
|
|
|
@session
|
|
def call_raw(self, msg):
|
|
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
|
|
self.transport.write(msg)
|
|
return self.transport.read()
|
|
|
|
@session
|
|
def call(self, msg):
|
|
resp = self.call_raw(msg)
|
|
handler_name = "callback_%s" % resp.__class__.__name__
|
|
handler = getattr(self, handler_name, None)
|
|
|
|
if handler is not None:
|
|
msg = handler(resp)
|
|
if msg is None:
|
|
raise ValueError("Callback %s must return protobuf message, not None" % handler)
|
|
resp = self.call(msg)
|
|
|
|
return resp
|
|
|
|
def callback_Failure(self, msg):
|
|
if msg.code in (proto.FailureType.PinInvalid,
|
|
proto.FailureType.PinCancelled, proto.FailureType.PinExpected):
|
|
raise PinException(msg.code, msg.message)
|
|
|
|
raise CallException(msg.code, msg.message)
|
|
|
|
def register_message(self, msg):
|
|
'''Allow application to register custom protobuf message type'''
|
|
mapping.register_message(msg)
|
|
|
|
|
|
class TextUIMixin(object):
|
|
# This class demonstrates easy test-based UI
|
|
# integration between the device and wallet.
|
|
# You can implement similar functionality
|
|
# by implementing your own GuiMixin with
|
|
# graphical widgets for every type of these callbacks.
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(TextUIMixin, self).__init__(*args, **kwargs)
|
|
|
|
@staticmethod
|
|
def print(text):
|
|
print(text, file=sys.stderr)
|
|
|
|
def callback_ButtonRequest(self, msg):
|
|
# log("Sending ButtonAck for %s " % get_buttonrequest_value(msg.code))
|
|
return proto.ButtonAck()
|
|
|
|
def callback_RecoveryMatrix(self, msg):
|
|
if self.recovery_matrix_first_pass:
|
|
self.recovery_matrix_first_pass = False
|
|
self.print("Use the numeric keypad to describe positions. For the word list use only left and right keys.")
|
|
self.print("Use backspace to correct an entry. The keypad layout is:")
|
|
self.print(" 7 8 9 7 | 9")
|
|
self.print(" 4 5 6 4 | 6")
|
|
self.print(" 1 2 3 1 | 3")
|
|
while True:
|
|
character = getch()
|
|
if character in ('\x03', '\x04'):
|
|
return proto.Cancel()
|
|
|
|
if character in ('\x08', '\x7f'):
|
|
return proto.WordAck(word='\x08')
|
|
|
|
# ignore middle column if only 6 keys requested.
|
|
if msg.type == proto.WordRequestType.Matrix6 and character in ('2', '5', '8'):
|
|
continue
|
|
|
|
if character.isdigit():
|
|
return proto.WordAck(word=character)
|
|
|
|
def callback_PinMatrixRequest(self, msg):
|
|
if msg.type == proto.PinMatrixRequestType.Current:
|
|
desc = 'current PIN'
|
|
elif msg.type == proto.PinMatrixRequestType.NewFirst:
|
|
desc = 'new PIN'
|
|
elif msg.type == proto.PinMatrixRequestType.NewSecond:
|
|
desc = 'new PIN again'
|
|
else:
|
|
desc = 'PIN'
|
|
|
|
self.print("Use the numeric keypad to describe number positions. The layout is:")
|
|
self.print(" 7 8 9")
|
|
self.print(" 4 5 6")
|
|
self.print(" 1 2 3")
|
|
self.print("Please enter %s: " % desc)
|
|
pin = getpass.getpass('')
|
|
if not pin.isdigit():
|
|
raise ValueError('Non-numerical PIN provided')
|
|
return proto.PinMatrixAck(pin=pin)
|
|
|
|
def callback_PassphraseRequest(self, msg):
|
|
if msg.on_device is True:
|
|
return proto.PassphraseAck()
|
|
|
|
if os.getenv("PASSPHRASE") is not None:
|
|
self.print("Passphrase required. Using PASSPHRASE environment variable.")
|
|
passphrase = Mnemonic.normalize_string(os.getenv("PASSPHRASE"))
|
|
return proto.PassphraseAck(passphrase=passphrase)
|
|
|
|
self.print("Passphrase required: ")
|
|
passphrase = getpass.getpass('')
|
|
self.print("Confirm your Passphrase: ")
|
|
if passphrase == getpass.getpass(''):
|
|
passphrase = Mnemonic.normalize_string(passphrase)
|
|
return proto.PassphraseAck(passphrase=passphrase)
|
|
else:
|
|
self.print("Passphrase did not match! ")
|
|
exit()
|
|
|
|
def callback_PassphraseStateRequest(self, msg):
|
|
return proto.PassphraseStateAck()
|
|
|
|
def callback_WordRequest(self, msg):
|
|
if msg.type in (proto.WordRequestType.Matrix9,
|
|
proto.WordRequestType.Matrix6):
|
|
return self.callback_RecoveryMatrix(msg)
|
|
self.print("Enter one word of mnemonic: ")
|
|
word = input()
|
|
if self.expand:
|
|
word = self.mnemonic_wordlist.expand_word(word)
|
|
return proto.WordAck(word=word)
|
|
|
|
|
|
class DebugLinkMixin(object):
|
|
# This class implements automatic responses
|
|
# and other functionality for unit tests
|
|
# for various callbacks, created in order
|
|
# to automatically pass unit tests.
|
|
#
|
|
# This mixing should be used only for purposes
|
|
# of unit testing, because it will fail to work
|
|
# without special DebugLink interface provided
|
|
# by the device.
|
|
DEBUG = LOG.getChild('debug_link').debug
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
super(DebugLinkMixin, self).__init__(*args, **kwargs)
|
|
self.debug = None
|
|
self.in_with_statement = 0
|
|
self.button_wait = 0
|
|
self.screenshot_id = 0
|
|
|
|
# Always press Yes and provide correct pin
|
|
self.setup_debuglink(True, True)
|
|
|
|
# Do not expect any specific response from device
|
|
self.expected_responses = None
|
|
|
|
# Use blank passphrase
|
|
self.set_passphrase('')
|
|
|
|
def close(self):
|
|
super(DebugLinkMixin, self).close()
|
|
if self.debug:
|
|
self.debug.close()
|
|
|
|
def set_debuglink(self, debug_transport):
|
|
self.debug = DebugLink(debug_transport)
|
|
|
|
def set_buttonwait(self, secs):
|
|
self.button_wait = secs
|
|
|
|
def __enter__(self):
|
|
# For usage in with/expected_responses
|
|
self.in_with_statement += 1
|
|
return self
|
|
|
|
def __exit__(self, _type, value, traceback):
|
|
self.in_with_statement -= 1
|
|
|
|
if _type is not None:
|
|
# Another exception raised
|
|
return False
|
|
|
|
# return isinstance(value, TypeError)
|
|
# Evaluate missed responses in 'with' statement
|
|
if self.expected_responses is not None and len(self.expected_responses):
|
|
raise RuntimeError("Some of expected responses didn't come from device: %s" %
|
|
[repr(x) for x in self.expected_responses])
|
|
|
|
# Cleanup
|
|
self.expected_responses = None
|
|
return False
|
|
|
|
def set_expected_responses(self, expected):
|
|
if not self.in_with_statement:
|
|
raise RuntimeError("Must be called inside 'with' statement")
|
|
self.expected_responses = expected
|
|
|
|
def setup_debuglink(self, button, pin_correct):
|
|
self.button = button # True -> YES button, False -> NO button
|
|
self.pin_correct = pin_correct
|
|
|
|
def set_passphrase(self, passphrase):
|
|
self.passphrase = Mnemonic.normalize_string(passphrase)
|
|
|
|
def set_mnemonic(self, mnemonic):
|
|
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(' ')
|
|
|
|
def call_raw(self, msg):
|
|
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
|
|
|
|
if SCREENSHOT and self.debug:
|
|
from PIL import Image
|
|
layout = self.debug.read_layout()
|
|
im = Image.new("RGB", (128, 64))
|
|
pix = im.load()
|
|
for x in range(128):
|
|
for y in range(64):
|
|
rx, ry = 127 - x, 63 - y
|
|
if (ord(layout[rx + (ry / 8) * 128]) & (1 << (ry % 8))) > 0:
|
|
pix[x, y] = (255, 255, 255)
|
|
im.save('scr%05d.png' % self.screenshot_id)
|
|
self.screenshot_id += 1
|
|
|
|
resp = super(DebugLinkMixin, self).call_raw(msg)
|
|
self._check_request(resp)
|
|
return resp
|
|
|
|
def _check_request(self, msg):
|
|
__tracebackhide__ = True # pytest traceback hiding - this function won't appear in tracebacks
|
|
|
|
if self.expected_responses is not None:
|
|
try:
|
|
expected = self.expected_responses.pop(0)
|
|
except IndexError:
|
|
raise AssertionError(proto.FailureType.UnexpectedMessage,
|
|
"Got %s, but no message has been expected" % repr(msg))
|
|
|
|
if msg.__class__ != expected.__class__:
|
|
raise AssertionError(proto.FailureType.UnexpectedMessage,
|
|
"Expected %s, got %s" % (repr(expected), repr(msg)))
|
|
|
|
for field, value in expected.__dict__.items():
|
|
if value is None or value == []:
|
|
continue
|
|
if getattr(msg, field) != value:
|
|
raise AssertionError(proto.FailureType.UnexpectedMessage,
|
|
"Expected %s, got %s" % (repr(expected), repr(msg)))
|
|
|
|
def callback_ButtonRequest(self, msg):
|
|
self.DEBUG("ButtonRequest code: " + get_buttonrequest_value(msg.code))
|
|
|
|
self.DEBUG("Pressing button " + str(self.button))
|
|
if self.button_wait:
|
|
self.DEBUG("Waiting %d seconds " % self.button_wait)
|
|
time.sleep(self.button_wait)
|
|
self.debug.press_button(self.button)
|
|
return proto.ButtonAck()
|
|
|
|
def callback_PinMatrixRequest(self, msg):
|
|
if self.pin_correct:
|
|
pin = self.debug.read_pin_encoded()
|
|
else:
|
|
pin = '444222'
|
|
return proto.PinMatrixAck(pin=pin)
|
|
|
|
def callback_PassphraseRequest(self, msg):
|
|
self.DEBUG("Provided passphrase: '%s'" % self.passphrase)
|
|
return proto.PassphraseAck(passphrase=self.passphrase)
|
|
|
|
def callback_PassphraseStateRequest(self, msg):
|
|
return proto.PassphraseStateAck()
|
|
|
|
def callback_WordRequest(self, msg):
|
|
(word, pos) = self.debug.read_recovery_word()
|
|
if word != '':
|
|
return proto.WordAck(word=word)
|
|
if pos != 0:
|
|
return proto.WordAck(word=self.mnemonic[pos - 1])
|
|
|
|
raise RuntimeError("Unexpected call")
|
|
|
|
|
|
class ProtocolMixin(object):
|
|
VENDORS = ('bitcointrezor.com', 'trezor.io')
|
|
|
|
def __init__(self, state=None, *args, **kwargs):
|
|
super(ProtocolMixin, self).__init__(*args, **kwargs)
|
|
self.state = state
|
|
self.init_device()
|
|
self.tx_api = None
|
|
|
|
def set_tx_api(self, tx_api):
|
|
self.tx_api = tx_api
|
|
|
|
def init_device(self):
|
|
init_msg = proto.Initialize()
|
|
if self.state is not None:
|
|
init_msg.state = self.state
|
|
self.features = expect(proto.Features)(self.call)(init_msg)
|
|
if str(self.features.vendor) not in self.VENDORS:
|
|
raise RuntimeError("Unsupported device")
|
|
|
|
def _get_local_entropy(self):
|
|
return os.urandom(32)
|
|
|
|
@staticmethod
|
|
def _convert_prime(n: tools.Address) -> tools.Address:
|
|
# Convert minus signs to uint32 with flag
|
|
return [tools.H_(int(abs(x))) if x < 0 else x for x in n]
|
|
|
|
@staticmethod
|
|
def expand_path(n):
|
|
warnings.warn('expand_path is deprecated, use tools.parse_path', DeprecationWarning)
|
|
return tools.parse_path(n)
|
|
|
|
@expect(proto.PublicKey)
|
|
def get_public_node(self, n, ecdsa_curve_name=None, show_display=False, coin_name=None):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.GetPublicKey(address_n=n, ecdsa_curve_name=ecdsa_curve_name, show_display=show_display, coin_name=coin_name))
|
|
|
|
@field('address')
|
|
@expect(proto.Address)
|
|
def get_address(self, coin_name, n, show_display=False, multisig=None, script_type=proto.InputScriptType.SPENDADDRESS):
|
|
n = self._convert_prime(n)
|
|
if multisig:
|
|
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name, show_display=show_display, multisig=multisig, script_type=script_type))
|
|
else:
|
|
return self.call(proto.GetAddress(address_n=n, coin_name=coin_name, show_display=show_display, script_type=script_type))
|
|
|
|
@field('address')
|
|
@expect(proto.EthereumAddress)
|
|
def ethereum_get_address(self, n, show_display=False, multisig=None):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.EthereumGetAddress(address_n=n, show_display=show_display))
|
|
|
|
@session
|
|
def ethereum_sign_tx(self, n, nonce, gas_price, gas_limit, to, value, data=None, chain_id=None, tx_type=None):
|
|
def int_to_big_endian(value):
|
|
return value.to_bytes((value.bit_length() + 7) // 8, 'big')
|
|
|
|
n = self._convert_prime(n)
|
|
|
|
msg = proto.EthereumSignTx(
|
|
address_n=n,
|
|
nonce=int_to_big_endian(nonce),
|
|
gas_price=int_to_big_endian(gas_price),
|
|
gas_limit=int_to_big_endian(gas_limit),
|
|
value=int_to_big_endian(value))
|
|
|
|
if to:
|
|
msg.to = to
|
|
|
|
if data:
|
|
msg.data_length = len(data)
|
|
data, chunk = data[1024:], data[:1024]
|
|
msg.data_initial_chunk = chunk
|
|
|
|
if chain_id:
|
|
msg.chain_id = chain_id
|
|
|
|
if tx_type is not None:
|
|
msg.tx_type = tx_type
|
|
|
|
response = self.call(msg)
|
|
|
|
while response.data_length is not None:
|
|
data_length = response.data_length
|
|
data, chunk = data[data_length:], data[:data_length]
|
|
response = self.call(proto.EthereumTxAck(data_chunk=chunk))
|
|
|
|
return response.signature_v, response.signature_r, response.signature_s
|
|
|
|
@expect(proto.EthereumMessageSignature)
|
|
def ethereum_sign_message(self, n, message):
|
|
n = self._convert_prime(n)
|
|
message = normalize_nfc(message)
|
|
return self.call(proto.EthereumSignMessage(address_n=n, message=message))
|
|
|
|
def ethereum_verify_message(self, address, signature, message):
|
|
message = normalize_nfc(message)
|
|
try:
|
|
resp = self.call(proto.EthereumVerifyMessage(address=address, signature=signature, message=message))
|
|
except CallException as e:
|
|
resp = e
|
|
if isinstance(resp, proto.Success):
|
|
return True
|
|
return False
|
|
|
|
#
|
|
# Lisk functions
|
|
#
|
|
|
|
@field('address')
|
|
@expect(proto.LiskAddress)
|
|
def lisk_get_address(self, n, show_display=False):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.LiskGetAddress(address_n=n, show_display=show_display))
|
|
|
|
@expect(proto.LiskPublicKey)
|
|
def lisk_get_public_key(self, n, show_display=False):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.LiskGetPublicKey(address_n=n, show_display=show_display))
|
|
|
|
@expect(proto.LiskMessageSignature)
|
|
def lisk_sign_message(self, n, message):
|
|
n = self._convert_prime(n)
|
|
message = normalize_nfc(message)
|
|
return self.call(proto.LiskSignMessage(address_n=n, message=message))
|
|
|
|
def lisk_verify_message(self, pubkey, signature, message):
|
|
message = normalize_nfc(message)
|
|
try:
|
|
resp = self.call(proto.LiskVerifyMessage(signature=signature, public_key=pubkey, message=message))
|
|
except CallException as e:
|
|
resp = e
|
|
return isinstance(resp, proto.Success)
|
|
|
|
@expect(proto.LiskSignedTx)
|
|
def lisk_sign_tx(self, n, transaction):
|
|
n = self._convert_prime(n)
|
|
|
|
def asset_to_proto(asset):
|
|
msg = proto.LiskTransactionAsset()
|
|
|
|
if "votes" in asset:
|
|
msg.votes = asset["votes"]
|
|
if "data" in asset:
|
|
msg.data = asset["data"]
|
|
if "signature" in asset:
|
|
msg.signature = proto.LiskSignatureType()
|
|
msg.signature.public_key = binascii.unhexlify(asset["signature"]["publicKey"])
|
|
if "delegate" in asset:
|
|
msg.delegate = proto.LiskDelegateType()
|
|
msg.delegate.username = asset["delegate"]["username"]
|
|
if "multisignature" in asset:
|
|
msg.multisignature = proto.LiskMultisignatureType()
|
|
msg.multisignature.min = asset["multisignature"]["min"]
|
|
msg.multisignature.life_time = asset["multisignature"]["lifetime"]
|
|
msg.multisignature.keys_group = asset["multisignature"]["keysgroup"]
|
|
return msg
|
|
|
|
msg = proto.LiskTransactionCommon()
|
|
|
|
msg.type = transaction["type"]
|
|
msg.fee = int(transaction["fee"]) # Lisk use strings for big numbers (javascript issue)
|
|
msg.amount = int(transaction["amount"]) # And we convert it back to number
|
|
msg.timestamp = transaction["timestamp"]
|
|
|
|
if "recipientId" in transaction:
|
|
msg.recipient_id = transaction["recipientId"]
|
|
if "senderPublicKey" in transaction:
|
|
msg.sender_public_key = binascii.unhexlify(transaction["senderPublicKey"])
|
|
if "requesterPublicKey" in transaction:
|
|
msg.requester_public_key = binascii.unhexlify(transaction["requesterPublicKey"])
|
|
if "signature" in transaction:
|
|
msg.signature = binascii.unhexlify(transaction["signature"])
|
|
|
|
msg.asset = asset_to_proto(transaction["asset"])
|
|
return self.call(proto.LiskSignTx(address_n=n, transaction=msg))
|
|
|
|
@field('entropy')
|
|
@expect(proto.Entropy)
|
|
def get_entropy(self, size):
|
|
return self.call(proto.GetEntropy(size=size))
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def ping(self, msg, button_protection=False, pin_protection=False, passphrase_protection=False):
|
|
msg = proto.Ping(message=msg,
|
|
button_protection=button_protection,
|
|
pin_protection=pin_protection,
|
|
passphrase_protection=passphrase_protection)
|
|
return self.call(msg)
|
|
|
|
def get_device_id(self):
|
|
return self.features.device_id
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def apply_settings(self, label=None, language=None, use_passphrase=None, homescreen=None, passphrase_source=None, auto_lock_delay_ms=None):
|
|
settings = proto.ApplySettings()
|
|
if label is not None:
|
|
settings.label = label
|
|
if language:
|
|
settings.language = language
|
|
if use_passphrase is not None:
|
|
settings.use_passphrase = use_passphrase
|
|
if homescreen is not None:
|
|
settings.homescreen = homescreen
|
|
if passphrase_source is not None:
|
|
settings.passphrase_source = passphrase_source
|
|
if auto_lock_delay_ms is not None:
|
|
settings.auto_lock_delay_ms = auto_lock_delay_ms
|
|
|
|
out = self.call(settings)
|
|
self.init_device() # Reload Features
|
|
return out
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def apply_flags(self, flags):
|
|
out = self.call(proto.ApplyFlags(flags=flags))
|
|
self.init_device() # Reload Features
|
|
return out
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def clear_session(self):
|
|
return self.call(proto.ClearSession())
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def change_pin(self, remove=False):
|
|
ret = self.call(proto.ChangePin(remove=remove))
|
|
self.init_device() # Re-read features
|
|
return ret
|
|
|
|
@expect(proto.MessageSignature)
|
|
def sign_message(self, coin_name, n, message, script_type=proto.InputScriptType.SPENDADDRESS):
|
|
n = self._convert_prime(n)
|
|
message = normalize_nfc(message)
|
|
return self.call(proto.SignMessage(coin_name=coin_name, address_n=n, message=message, script_type=script_type))
|
|
|
|
@expect(proto.SignedIdentity)
|
|
def sign_identity(self, identity, challenge_hidden, challenge_visual, ecdsa_curve_name=None):
|
|
return self.call(proto.SignIdentity(identity=identity, challenge_hidden=challenge_hidden, challenge_visual=challenge_visual, ecdsa_curve_name=ecdsa_curve_name))
|
|
|
|
@expect(proto.ECDHSessionKey)
|
|
def get_ecdh_session_key(self, identity, peer_public_key, ecdsa_curve_name=None):
|
|
return self.call(proto.GetECDHSessionKey(identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name=ecdsa_curve_name))
|
|
|
|
@expect(proto.CosiCommitment)
|
|
def cosi_commit(self, n, data):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.CosiCommit(address_n=n, data=data))
|
|
|
|
@expect(proto.CosiSignature)
|
|
def cosi_sign(self, n, data, global_commitment, global_pubkey):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.CosiSign(address_n=n, data=data, global_commitment=global_commitment, global_pubkey=global_pubkey))
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def set_u2f_counter(self, u2f_counter):
|
|
ret = self.call(proto.SetU2FCounter(u2f_counter=u2f_counter))
|
|
return ret
|
|
|
|
@field("address")
|
|
@expect(proto.NEMAddress)
|
|
def nem_get_address(self, n, network, show_display=False):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.NEMGetAddress(address_n=n, network=network, show_display=show_display))
|
|
|
|
@expect(proto.NEMSignedTx)
|
|
def nem_sign_tx(self, n, transaction):
|
|
n = self._convert_prime(n)
|
|
try:
|
|
msg = nem.create_sign_tx(transaction)
|
|
except ValueError as e:
|
|
raise CallException(e.args)
|
|
|
|
assert msg.transaction is not None
|
|
msg.transaction.address_n = n
|
|
return self.call(msg)
|
|
|
|
def verify_message(self, coin_name, address, signature, message):
|
|
message = normalize_nfc(message)
|
|
try:
|
|
resp = self.call(proto.VerifyMessage(address=address, signature=signature, message=message, coin_name=coin_name))
|
|
except CallException as e:
|
|
resp = e
|
|
if isinstance(resp, proto.Success):
|
|
return True
|
|
return False
|
|
|
|
@expect(proto.EncryptedMessage)
|
|
def encrypt_message(self, pubkey, message, display_only, coin_name, n):
|
|
if coin_name and n:
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.EncryptMessage(pubkey=pubkey, message=message, display_only=display_only, coin_name=coin_name, address_n=n))
|
|
else:
|
|
return self.call(proto.EncryptMessage(pubkey=pubkey, message=message, display_only=display_only))
|
|
|
|
@expect(proto.DecryptedMessage)
|
|
def decrypt_message(self, n, nonce, message, msg_hmac):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.DecryptMessage(address_n=n, nonce=nonce, message=message, hmac=msg_hmac))
|
|
|
|
@field('value')
|
|
@expect(proto.CipheredKeyValue)
|
|
def encrypt_keyvalue(self, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.CipherKeyValue(address_n=n,
|
|
key=key,
|
|
value=value,
|
|
encrypt=True,
|
|
ask_on_encrypt=ask_on_encrypt,
|
|
ask_on_decrypt=ask_on_decrypt,
|
|
iv=iv))
|
|
|
|
@field('value')
|
|
@expect(proto.CipheredKeyValue)
|
|
def decrypt_keyvalue(self, n, key, value, ask_on_encrypt=True, ask_on_decrypt=True, iv=b''):
|
|
n = self._convert_prime(n)
|
|
return self.call(proto.CipherKeyValue(address_n=n,
|
|
key=key,
|
|
value=value,
|
|
encrypt=False,
|
|
ask_on_encrypt=ask_on_encrypt,
|
|
ask_on_decrypt=ask_on_decrypt,
|
|
iv=iv))
|
|
|
|
def _prepare_sign_tx(self, inputs, outputs):
|
|
tx = proto.TransactionType()
|
|
tx.inputs = inputs
|
|
tx.outputs = outputs
|
|
|
|
txes = {None: tx}
|
|
|
|
for inp in inputs:
|
|
if inp.prev_hash in txes:
|
|
continue
|
|
|
|
if inp.script_type in (proto.InputScriptType.SPENDP2SHWITNESS,
|
|
proto.InputScriptType.SPENDWITNESS):
|
|
continue
|
|
|
|
if not self.tx_api:
|
|
raise RuntimeError('TX_API not defined')
|
|
|
|
prev_tx = self.tx_api.get_tx(binascii.hexlify(inp.prev_hash).decode('utf-8'))
|
|
txes[inp.prev_hash] = prev_tx
|
|
|
|
return txes
|
|
|
|
@session
|
|
def sign_tx(self, coin_name, inputs, outputs, version=None, lock_time=None, expiry=None, overwintered=None, debug_processor=None):
|
|
|
|
# start = time.time()
|
|
txes = self._prepare_sign_tx(inputs, outputs)
|
|
|
|
# Prepare and send initial message
|
|
tx = proto.SignTx()
|
|
tx.inputs_count = len(inputs)
|
|
tx.outputs_count = len(outputs)
|
|
tx.coin_name = coin_name
|
|
if version is not None:
|
|
tx.version = version
|
|
if lock_time is not None:
|
|
tx.lock_time = lock_time
|
|
if expiry is not None:
|
|
tx.expiry = expiry
|
|
if overwintered is not None:
|
|
tx.overwintered = overwintered
|
|
res = self.call(tx)
|
|
|
|
# Prepare structure for signatures
|
|
signatures = [None] * len(inputs)
|
|
serialized_tx = b''
|
|
|
|
counter = 0
|
|
while True:
|
|
counter += 1
|
|
|
|
if isinstance(res, proto.Failure):
|
|
raise CallException("Signing failed")
|
|
|
|
if not isinstance(res, proto.TxRequest):
|
|
raise CallException("Unexpected message")
|
|
|
|
# If there's some part of signed transaction, let's add it
|
|
if res.serialized and res.serialized.serialized_tx:
|
|
# log("RECEIVED PART OF SERIALIZED TX (%d BYTES)" % len(res.serialized.serialized_tx))
|
|
serialized_tx += res.serialized.serialized_tx
|
|
|
|
if res.serialized and res.serialized.signature_index is not None:
|
|
if signatures[res.serialized.signature_index] is not None:
|
|
raise ValueError("Signature for index %d already filled" % res.serialized.signature_index)
|
|
signatures[res.serialized.signature_index] = res.serialized.signature
|
|
|
|
if res.request_type == proto.RequestType.TXFINISHED:
|
|
# Device didn't ask for more information, finish workflow
|
|
break
|
|
|
|
# Device asked for one more information, let's process it.
|
|
if not res.details.tx_hash:
|
|
current_tx = txes[None]
|
|
else:
|
|
current_tx = txes[bytes(res.details.tx_hash)]
|
|
|
|
if res.request_type == proto.RequestType.TXMETA:
|
|
msg = proto.TransactionType()
|
|
msg.version = current_tx.version
|
|
msg.lock_time = current_tx.lock_time
|
|
msg.inputs_cnt = len(current_tx.inputs)
|
|
if res.details.tx_hash:
|
|
msg.outputs_cnt = len(current_tx.bin_outputs)
|
|
else:
|
|
msg.outputs_cnt = len(current_tx.outputs)
|
|
msg.extra_data_len = len(current_tx.extra_data) if current_tx.extra_data else 0
|
|
res = self.call(proto.TxAck(tx=msg))
|
|
continue
|
|
|
|
elif res.request_type == proto.RequestType.TXINPUT:
|
|
msg = proto.TransactionType()
|
|
msg.inputs = [current_tx.inputs[res.details.request_index]]
|
|
if debug_processor is not None:
|
|
# msg needs to be deep copied so when it's modified
|
|
# the other messages stay intact
|
|
from copy import deepcopy
|
|
msg = deepcopy(msg)
|
|
# If debug_processor function is provided,
|
|
# pass thru it the request and prepared response.
|
|
# This is useful for tests, see test_msg_signtx
|
|
msg = debug_processor(res, msg)
|
|
|
|
res = self.call(proto.TxAck(tx=msg))
|
|
continue
|
|
|
|
elif res.request_type == proto.RequestType.TXOUTPUT:
|
|
msg = proto.TransactionType()
|
|
if res.details.tx_hash:
|
|
msg.bin_outputs = [current_tx.bin_outputs[res.details.request_index]]
|
|
else:
|
|
msg.outputs = [current_tx.outputs[res.details.request_index]]
|
|
|
|
if debug_processor is not None:
|
|
# msg needs to be deep copied so when it's modified
|
|
# the other messages stay intact
|
|
from copy import deepcopy
|
|
msg = deepcopy(msg)
|
|
# If debug_processor function is provided,
|
|
# pass thru it the request and prepared response.
|
|
# This is useful for tests, see test_msg_signtx
|
|
msg = debug_processor(res, msg)
|
|
|
|
res = self.call(proto.TxAck(tx=msg))
|
|
continue
|
|
|
|
elif res.request_type == proto.RequestType.TXEXTRADATA:
|
|
o, l = res.details.extra_data_offset, res.details.extra_data_len
|
|
msg = proto.TransactionType()
|
|
msg.extra_data = current_tx.extra_data[o:o + l]
|
|
res = self.call(proto.TxAck(tx=msg))
|
|
continue
|
|
|
|
if None in signatures:
|
|
raise RuntimeError("Some signatures are missing!")
|
|
|
|
# log("SIGNED IN %.03f SECONDS, CALLED %d MESSAGES, %d BYTES" %
|
|
# (time.time() - start, counter, len(serialized_tx)))
|
|
|
|
return (signatures, serialized_tx)
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def wipe_device(self):
|
|
ret = self.call(proto.WipeDevice())
|
|
self.init_device()
|
|
return ret
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def recovery_device(self, word_count, passphrase_protection, pin_protection, label, language, type=proto.RecoveryDeviceType.ScrambledWords, expand=False, dry_run=False):
|
|
if self.features.initialized and not dry_run:
|
|
raise RuntimeError("Device is initialized already. Call wipe_device() and try again.")
|
|
|
|
if word_count not in (12, 18, 24):
|
|
raise ValueError("Invalid word count. Use 12/18/24")
|
|
|
|
self.recovery_matrix_first_pass = True
|
|
|
|
self.expand = expand
|
|
if self.expand:
|
|
# optimization to load the wordlist once, instead of for each recovery word
|
|
self.mnemonic_wordlist = Mnemonic('english')
|
|
|
|
res = self.call(proto.RecoveryDevice(
|
|
word_count=int(word_count),
|
|
passphrase_protection=bool(passphrase_protection),
|
|
pin_protection=bool(pin_protection),
|
|
label=label,
|
|
language=language,
|
|
enforce_wordlist=True,
|
|
type=type,
|
|
dry_run=dry_run))
|
|
|
|
self.init_device()
|
|
return res
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
@session
|
|
def reset_device(self, display_random, strength, passphrase_protection, pin_protection, label, language, u2f_counter=0, skip_backup=False):
|
|
if self.features.initialized:
|
|
raise RuntimeError("Device is initialized already. Call wipe_device() and try again.")
|
|
|
|
# Begin with device reset workflow
|
|
msg = proto.ResetDevice(display_random=display_random,
|
|
strength=strength,
|
|
passphrase_protection=bool(passphrase_protection),
|
|
pin_protection=bool(pin_protection),
|
|
language=language,
|
|
label=label,
|
|
u2f_counter=u2f_counter,
|
|
skip_backup=bool(skip_backup))
|
|
|
|
resp = self.call(msg)
|
|
if not isinstance(resp, proto.EntropyRequest):
|
|
raise RuntimeError("Invalid response, expected EntropyRequest")
|
|
|
|
external_entropy = self._get_local_entropy()
|
|
LOG.debug("Computer generated entropy: " + binascii.hexlify(external_entropy).decode())
|
|
ret = self.call(proto.EntropyAck(entropy=external_entropy))
|
|
self.init_device()
|
|
return ret
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def backup_device(self):
|
|
ret = self.call(proto.BackupDevice())
|
|
return ret
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def load_device_by_mnemonic(self, mnemonic, pin, passphrase_protection, label, language='english', skip_checksum=False, expand=False):
|
|
# Convert mnemonic to UTF8 NKFD
|
|
mnemonic = Mnemonic.normalize_string(mnemonic)
|
|
|
|
# Convert mnemonic to ASCII stream
|
|
mnemonic = mnemonic.encode('utf-8')
|
|
|
|
m = Mnemonic('english')
|
|
|
|
if expand:
|
|
mnemonic = m.expand(mnemonic)
|
|
|
|
if not skip_checksum and not m.check(mnemonic):
|
|
raise ValueError("Invalid mnemonic checksum")
|
|
|
|
if self.features.initialized:
|
|
raise RuntimeError("Device is initialized already. Call wipe_device() and try again.")
|
|
|
|
resp = self.call(proto.LoadDevice(mnemonic=mnemonic, pin=pin,
|
|
passphrase_protection=passphrase_protection,
|
|
language=language,
|
|
label=label,
|
|
skip_checksum=skip_checksum))
|
|
self.init_device()
|
|
return resp
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def load_device_by_xprv(self, xprv, pin, passphrase_protection, label, language):
|
|
if self.features.initialized:
|
|
raise RuntimeError("Device is initialized already. Call wipe_device() and try again.")
|
|
|
|
if xprv[0:4] not in ('xprv', 'tprv'):
|
|
raise ValueError("Unknown type of xprv")
|
|
|
|
if not 100 < len(xprv) < 112: # yes this is correct in Python
|
|
raise ValueError("Invalid length of xprv")
|
|
|
|
node = proto.HDNodeType()
|
|
data = binascii.hexlify(tools.b58decode(xprv, None))
|
|
|
|
if data[90:92] != b'00':
|
|
raise ValueError("Contain invalid private key")
|
|
|
|
checksum = binascii.hexlify(tools.btc_hash(binascii.unhexlify(data[:156]))[:4])
|
|
if checksum != data[156:]:
|
|
raise ValueError("Checksum doesn't match")
|
|
|
|
# version 0488ade4
|
|
# depth 00
|
|
# fingerprint 00000000
|
|
# child_num 00000000
|
|
# chaincode 873dff81c02f525623fd1fe5167eac3a55a049de3d314bb42ee227ffed37d508
|
|
# privkey 00e8f32e723decf4051aefac8e2c93c9c5b214313817cdb01a1494b917c8436b35
|
|
# checksum e77e9d71
|
|
|
|
node.depth = int(data[8:10], 16)
|
|
node.fingerprint = int(data[10:18], 16)
|
|
node.child_num = int(data[18:26], 16)
|
|
node.chain_code = binascii.unhexlify(data[26:90])
|
|
node.private_key = binascii.unhexlify(data[92:156]) # skip 0x00 indicating privkey
|
|
|
|
resp = self.call(proto.LoadDevice(node=node,
|
|
pin=pin,
|
|
passphrase_protection=passphrase_protection,
|
|
language=language,
|
|
label=label))
|
|
self.init_device()
|
|
return resp
|
|
|
|
@session
|
|
def firmware_update(self, fp):
|
|
if self.features.bootloader_mode is False:
|
|
raise RuntimeError("Device must be in bootloader mode")
|
|
|
|
data = fp.read()
|
|
|
|
resp = self.call(proto.FirmwareErase(length=len(data)))
|
|
if isinstance(resp, proto.Failure) and resp.code == proto.FailureType.FirmwareError:
|
|
return False
|
|
|
|
# TREZORv1 method
|
|
if isinstance(resp, proto.Success):
|
|
fingerprint = hashlib.sha256(data[256:]).hexdigest()
|
|
LOG.debug("Firmware fingerprint: " + fingerprint)
|
|
resp = self.call(proto.FirmwareUpload(payload=data))
|
|
if isinstance(resp, proto.Success):
|
|
return True
|
|
elif isinstance(resp, proto.Failure) and resp.code == proto.FailureType.FirmwareError:
|
|
return False
|
|
raise RuntimeError("Unexpected result %s" % resp)
|
|
|
|
# TREZORv2 method
|
|
if isinstance(resp, proto.FirmwareRequest):
|
|
import pyblake2
|
|
while True:
|
|
payload = data[resp.offset:resp.offset + resp.length]
|
|
digest = pyblake2.blake2s(payload).digest()
|
|
resp = self.call(proto.FirmwareUpload(payload=payload, hash=digest))
|
|
if isinstance(resp, proto.FirmwareRequest):
|
|
continue
|
|
elif isinstance(resp, proto.Success):
|
|
return True
|
|
elif isinstance(resp, proto.Failure) and resp.code == proto.FailureType.FirmwareError:
|
|
return False
|
|
raise RuntimeError("Unexpected result %s" % resp)
|
|
|
|
raise RuntimeError("Unexpected message %s" % resp)
|
|
|
|
@field('message')
|
|
@expect(proto.Success)
|
|
def self_test(self):
|
|
if self.features.bootloader_mode is False:
|
|
raise RuntimeError("Device must be in bootloader mode")
|
|
|
|
return self.call(proto.SelfTest(payload=b'\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC'))
|
|
|
|
@field('public_key')
|
|
@expect(proto.StellarPublicKey)
|
|
def stellar_get_public_key(self, address_n, show_display=False):
|
|
return self.call(proto.StellarGetPublicKey(address_n=address_n, show_display=show_display))
|
|
|
|
@field('address')
|
|
@expect(proto.StellarAddress)
|
|
def stellar_get_address(self, address_n, show_display=False):
|
|
return self.call(proto.StellarGetAddress(address_n=address_n, show_display=show_display))
|
|
|
|
def stellar_sign_transaction(self, tx, operations, address_n, network_passphrase=None):
|
|
# default networkPassphrase to the public network
|
|
if network_passphrase is None:
|
|
network_passphrase = "Public Global Stellar Network ; September 2015"
|
|
|
|
tx.network_passphrase = network_passphrase
|
|
tx.address_n = address_n
|
|
tx.num_operations = len(operations)
|
|
# Signing loop works as follows:
|
|
#
|
|
# 1. Start with tx (header information for the transaction) and operations (an array of operation protobuf messagess)
|
|
# 2. Send the tx header to the device
|
|
# 3. Receive a StellarTxOpRequest message
|
|
# 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message
|
|
# 5. The final message received will be StellarSignedTx which is returned from this method
|
|
resp = self.call(tx)
|
|
try:
|
|
while isinstance(resp, proto.StellarTxOpRequest):
|
|
resp = self.call(operations.pop(0))
|
|
except IndexError:
|
|
# pop from empty list
|
|
raise CallException("Stellar.UnexpectedEndOfOperations",
|
|
"Reached end of operations without a signature.") from None
|
|
|
|
if not isinstance(resp, proto.StellarSignedTx):
|
|
raise CallException(proto.FailureType.UnexpectedMessage, resp)
|
|
|
|
if operations:
|
|
raise CallException("Stellar.UnprocessedOperations",
|
|
"Received a signature before processing all operations.")
|
|
|
|
return resp
|
|
|
|
|
|
class TrezorClient(ProtocolMixin, TextUIMixin, BaseClient):
|
|
def __init__(self, transport, *args, **kwargs):
|
|
super().__init__(transport=transport, *args, **kwargs)
|
|
|
|
|
|
class TrezorClientDebugLink(ProtocolMixin, DebugLinkMixin, BaseClient):
|
|
def __init__(self, transport, *args, **kwargs):
|
|
super().__init__(transport=transport, *args, **kwargs)
|