1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-28 02:18:20 +00:00
trezor-firmware/trezorlib/client.py

577 lines
18 KiB
Python
Raw Normal View History

# This file is part of the Trezor project.
2016-11-25 21:53:55 +00:00
#
# Copyright (C) 2012-2018 SatoshiLabs and contributors
2016-11-25 21:53:55 +00:00
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
2016-11-25 21:53:55 +00:00
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
2016-11-25 21:53:55 +00:00
2018-05-11 13:24:24 +00:00
import functools
2018-08-13 16:21:24 +00:00
import getpass
2018-05-11 13:24:24 +00:00
import logging
import os
import sys
2014-04-09 19:42:30 +00:00
import time
import warnings
2014-02-21 00:48:11 +00:00
from mnemonic import Mnemonic
2018-08-13 16:21:24 +00:00
from . import (
btc,
cosi,
debuglink,
device,
ethereum,
firmware,
lisk,
mapping,
messages as proto,
misc,
nem,
stellar,
tools,
)
2017-06-23 19:31:42 +00:00
# Refuse to load on Python 2; this library only supports Python 3.
if sys.version_info[0] < 3:
    raise Exception("Trezorlib does not support Python 2 anymore.")

# When true (and a debug link is attached), DebugLinkMixin.call_raw saves a
# PNG of the device screen before every raw call.
SCREENSHOT = False

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
# make a getch function
try:
    import termios
    import tty

    # POSIX system. Create a getch that manipulates the tty.
    # On Windows, termios will fail to import.

    def getch():
        """Read a single character from stdin without waiting for Enter.

        The terminal is switched to raw mode for the duration of the read and
        its previous settings are always restored afterwards.
        """
        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            tty.setraw(fd)
            ch = sys.stdin.read(1)
        finally:
            # Restore the terminal even if the read was interrupted.
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
        return ch


except ImportError:
    # Windows system.
    # Use msvcrt's getch function.
    import msvcrt

    def getch():
        """Read a single keypress and return it as a string.

        Special (function/arrow) keys are skipped: they arrive as a prefix
        byte followed by a scancode, and both are discarded.
        """
        while True:
            key = msvcrt.getch()
            # msvcrt.getch() returns a bytes object, so the special-key
            # prefixes must be compared as bytes; comparing against the ints
            # 0x00 / 0xe0 never matches on Python 3.
            if key in (b"\x00", b"\xe0"):
                # skip special keys: read the scancode and repeat
                msvcrt.getch()
                continue

            return key.decode()
2017-06-23 19:31:42 +00:00
2014-02-13 15:46:21 +00:00
def get_buttonrequest_value(code):
    """Return the name of the ``proto.ButtonRequestType`` attribute equal to *code*.

    Raises ``IndexError`` when no attribute has that value.
    """
    request_types = proto.ButtonRequestType
    matching_names = [
        name for name in dir(request_types) if getattr(request_types, name) == code
    ]
    return matching_names[0]
2017-06-23 19:31:42 +00:00
2014-02-02 17:27:44 +00:00
class PinException(tools.CallException):
    """Call failure whose code is one of the PIN-related failure types."""
class MovedTo:
    """Deprecation redirector for methods that were formerly part of TrezorClient.

    Used as a class attribute (descriptor): looking the attribute up yields a
    callable that emits a ``DeprecationWarning`` and then forwards to the
    function the method was moved to.
    """

    def __init__(self, where):
        # `where` is the new home of the function; `name` is its dotted path,
        # used only in the warning text.
        self.where = where
        self.name = ".".join((where.__module__, where.__name__))

    def _deprecated_redirect(self, client, *args, **kwargs):
        """Warn about the move, then call the relocated function with *client* first."""
        message = "Function has been moved to %s" % self.name
        # stacklevel=2 attributes the warning to the caller of the old method.
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        target = self.where
        return target(client, *args, **kwargs)

    def __get__(self, instance, cls):
        redirect = self._deprecated_redirect
        if instance is not None:
            # Accessed on an instance: pre-bind it as the `client` argument.
            return functools.partial(redirect, instance)
        # Accessed on the class: the caller supplies `client` explicitly.
        return redirect
2017-06-23 19:31:42 +00:00
2014-02-13 15:46:21 +00:00
class BaseClient(object):
    """Very basic layer for sending raw protobuf messages to the device
    and getting its response back.

    Responses are dispatched by :meth:`call` to a ``callback_<ClassName>``
    method when the instance defines one for the response's class.
    """

    def __init__(self, transport, **kwargs):
        """Store *transport*; extra keyword arguments are accepted but not used."""
        device_path = transport.get_path()
        LOG.info("creating client instance for device: {}".format(device_path))
        self.transport = transport
        super().__init__()  # *args, **kwargs)

    def close(self):
        """Nothing to release at this layer; subclasses may override."""
        pass

    def cancel(self):
        """Write a ``Cancel`` message to the transport without reading a reply."""
        self.transport.write(proto.Cancel())

    @tools.session
    def call_raw(self, msg):
        """Write *msg* to the transport and return whatever is read back."""
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
        transport = self.transport
        transport.write(msg)
        return transport.read()

    @tools.session
    def call(self, msg):
        """Send *msg* and resolve intermediate responses through callbacks.

        If a ``callback_<ResponseClassName>`` method exists, it is invoked with
        the response and its return value is sent as the next message; this
        repeats until a response without a handler arrives, which is returned.

        Raises ``ValueError`` when a callback returns ``None``.
        """
        resp = self.call_raw(msg)
        handler = getattr(self, "callback_%s" % resp.__class__.__name__, None)

        if handler is None:
            # No callback registered for this response type: it is final.
            return resp

        reply = handler(resp)
        if reply is None:
            raise ValueError(
                "Callback %s must return protobuf message, not None" % handler
            )
        return self.call(reply)

    def callback_Failure(self, msg):
        """Turn a failure response into an exception.

        PIN-related failure codes raise ``PinException``; every other code
        raises ``tools.CallException``.
        """
        pin_failures = (
            proto.FailureType.PinInvalid,
            proto.FailureType.PinCancelled,
            proto.FailureType.PinExpected,
        )
        error_cls = PinException if msg.code in pin_failures else tools.CallException
        raise error_cls(msg.code, msg.message)

    def register_message(self, msg):
        """Allow application to register custom protobuf message type"""
        mapping.register_message(msg)
2017-06-23 19:31:42 +00:00
2014-02-13 15:46:21 +00:00
class TextUIMixin(object):
    """Simple text-based UI integration between the device and a wallet.

    Each ``callback_*`` method answers one kind of device request by
    prompting on the terminal. A GUI client can provide the same callbacks
    backed by graphical widgets instead.
    """

    def __init__(self, *args, **kwargs):
        super(TextUIMixin, self).__init__(*args, **kwargs)

    @staticmethod
    def print(text):
        """Write *text* followed by a newline to standard error."""
        print(text, file=sys.stderr)

    def callback_ButtonRequest(self, msg):
        """Acknowledge a button request; the user confirms on the device."""
        return proto.ButtonAck()

    def callback_RecoveryMatrix(self, msg):
        """Read one keypad key for matrix-style recovery and return the reply.

        Returns ``proto.Cancel()`` on Ctrl-C / Ctrl-D, a ``WordAck`` carrying
        a backspace character on backspace/delete, and otherwise a ``WordAck``
        carrying the digit that was pressed.
        """
        # NOTE(review): `recovery_matrix_first_pass` is not initialised in this
        # class; presumably the recovery workflow sets it before this callback
        # can fire - confirm against the caller.
        if self.recovery_matrix_first_pass:
            # Show the instructions only once per recovery.
            self.recovery_matrix_first_pass = False
            self.print(
                "Use the numeric keypad to describe positions. For the word list use only left and right keys."
            )
            self.print("Use backspace to correct an entry. The keypad layout is:")
            self.print(" 7 8 9 7 | 9")
            self.print(" 4 5 6 4 | 6")
            self.print(" 1 2 3 1 | 3")

        while True:
            character = getch()
            if character in ("\x03", "\x04"):
                # Ctrl-C / Ctrl-D
                return proto.Cancel()

            if character in ("\x08", "\x7f"):
                # Backspace / delete
                return proto.WordAck(word="\x08")

            # ignore middle column if only 6 keys requested.
            if msg.type == proto.WordRequestType.Matrix6 and character in (
                "2",
                "5",
                "8",
            ):
                continue

            if character.isdigit():
                return proto.WordAck(word=character)

    def callback_PinMatrixRequest(self, msg):
        """Prompt for a PIN (entered as keypad positions) and return the ack.

        Raises ``ValueError`` if the input is not purely numeric.
        """
        if msg.type == proto.PinMatrixRequestType.Current:
            desc = "current PIN"
        elif msg.type == proto.PinMatrixRequestType.NewFirst:
            desc = "new PIN"
        elif msg.type == proto.PinMatrixRequestType.NewSecond:
            desc = "new PIN again"
        else:
            desc = "PIN"

        self.print(
            "Use the numeric keypad to describe number positions. The layout is:"
        )
        self.print(" 7 8 9")
        self.print(" 4 5 6")
        self.print(" 1 2 3")
        self.print("Please enter %s: " % desc)
        pin = getpass.getpass("")
        if not pin.isdigit():
            raise ValueError("Non-numerical PIN provided")
        return proto.PinMatrixAck(pin=pin)

    def callback_PassphraseRequest(self, msg):
        """Provide the passphrase requested by the device.

        Order of precedence: entry on the device itself (empty ack), the
        ``PASSPHRASE`` environment variable, then an interactive prompt that
        asks twice. If the two interactive entries differ, the process exits.
        """
        if msg.on_device is True:
            return proto.PassphraseAck()

        env_passphrase = os.getenv("PASSPHRASE")
        if env_passphrase is not None:
            self.print("Passphrase required. Using PASSPHRASE environment variable.")
            passphrase = Mnemonic.normalize_string(env_passphrase)
            return proto.PassphraseAck(passphrase=passphrase)

        self.print("Passphrase required: ")
        passphrase = getpass.getpass("")
        self.print("Confirm your Passphrase: ")
        if passphrase == getpass.getpass(""):
            passphrase = Mnemonic.normalize_string(passphrase)
            return proto.PassphraseAck(passphrase=passphrase)

        self.print("Passphrase did not match! ")
        # Use sys.exit(): the bare exit() helper is injected by the `site`
        # module and is not guaranteed to exist (e.g. with `python -S`).
        sys.exit()

    def callback_PassphraseStateRequest(self, msg):
        """Acknowledge the passphrase state without sending any data back."""
        return proto.PassphraseStateAck()

    def callback_WordRequest(self, msg):
        """Answer a recovery word request.

        Matrix request types are delegated to :meth:`callback_RecoveryMatrix`;
        otherwise one word is read from standard input.
        """
        if msg.type in (proto.WordRequestType.Matrix9, proto.WordRequestType.Matrix6):
            return self.callback_RecoveryMatrix(msg)

        self.print("Enter one word of mnemonic: ")
        word = input()
        # NOTE(review): `expand` and `mnemonic_wordlist` are not set in this
        # class; presumably assigned by the recovery workflow - confirm.
        if self.expand:
            word = self.mnemonic_wordlist.expand_word(word)
        return proto.WordAck(word=word)
2017-06-23 19:31:42 +00:00
2014-02-13 15:46:21 +00:00
class DebugLinkMixin(object):
    """Automatic responses and helpers for unit tests.

    This mixin should be used only for purposes of unit testing, because it
    will fail to work without the special DebugLink interface provided by the
    device.
    """

    DEBUG = LOG.getChild("debug_link").debug

    def __init__(self, *args, **kwargs):
        super(DebugLinkMixin, self).__init__(*args, **kwargs)
        self.debug = None
        self.in_with_statement = 0
        self.button_wait = 0
        self.screenshot_id = 0

        # Always press Yes and provide correct pin
        self.setup_debuglink(True, True)

        # Do not expect any specific response from device
        self.expected_responses = None

        # Use blank passphrase
        self.set_passphrase("")

    def close(self):
        """Close the client and, if one is attached, the debug link."""
        super(DebugLinkMixin, self).close()
        if self.debug:
            self.debug.close()

    def set_debuglink(self, debug_transport):
        """Attach a debug link that talks over *debug_transport*."""
        self.debug = debuglink.DebugLink(debug_transport)

    def set_buttonwait(self, secs):
        """Set the delay (seconds) applied before each automatic button press."""
        self.button_wait = secs

    def __enter__(self):
        # For usage in with/expected_responses
        self.in_with_statement += 1
        return self

    def __exit__(self, _type, value, traceback):
        self.in_with_statement -= 1

        # Always drop the expectation list, including when the body raised or
        # when we are about to raise below. Leaving it set would make every
        # later call_raw() outside the `with` block be checked against stale
        # expectations.
        missed = self.expected_responses
        self.expected_responses = None

        if _type is not None:
            # Another exception raised; let it propagate.
            return False

        # Evaluate missed responses in 'with' statement
        if missed:
            raise RuntimeError(
                "Some of expected responses didn't come from device: %s"
                % [repr(x) for x in missed]
            )

        return False

    def set_expected_responses(self, expected):
        """Set the ordered list of responses that following calls must produce.

        Raises ``RuntimeError`` when used outside a ``with`` block on the client.
        """
        if not self.in_with_statement:
            raise RuntimeError("Must be called inside 'with' statement")
        self.expected_responses = expected

    def setup_debuglink(self, button, pin_correct):
        """Configure the automatic answers for button and PIN requests."""
        self.button = button  # True -> YES button, False -> NO button
        self.pin_correct = pin_correct

    def set_passphrase(self, passphrase):
        """Store the normalized passphrase used to answer passphrase requests."""
        self.passphrase = Mnemonic.normalize_string(passphrase)

    def set_mnemonic(self, mnemonic):
        """Store the normalized mnemonic, split into words, for word requests."""
        self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")

    def _save_screenshot(self):
        """Read the display layout over the debug link and save it as a PNG."""
        from PIL import Image

        layout = self.debug.read_layout()
        im = Image.new("RGB", (128, 64))
        pix = im.load()
        for x in range(128):
            for y in range(64):
                # The index math flips both axes relative to the layout buffer.
                rx, ry = 127 - x, 63 - y
                # Integer division is required here: `/` yields a float on
                # Python 3, which is not a valid sequence index.
                cell = layout[rx + (ry // 8) * 128]
                # Indexing bytes gives an int already; only str elements need
                # ord(). NOTE(review): read_layout() presumably returns bytes.
                if not isinstance(cell, int):
                    cell = ord(cell)
                if cell & (1 << (ry % 8)):
                    pix[x, y] = (255, 255, 255)
        im.save("scr%05d.png" % self.screenshot_id)
        self.screenshot_id += 1

    def call_raw(self, msg):
        """Send *msg*, optionally taking a screenshot first, and verify the reply
        against the expected responses (if any are set)."""
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        if SCREENSHOT and self.debug:
            self._save_screenshot()

        resp = super(DebugLinkMixin, self).call_raw(msg)
        self._check_request(resp)
        return resp

    def _check_request(self, msg):
        """Compare *msg* with the next expected response.

        Does nothing when no expectations are set. Otherwise raises
        ``AssertionError`` if nothing more was expected, if the class differs,
        or if any field set on the expected message (not ``None`` and not an
        empty list) differs from the same field on *msg*.
        """
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612

        if self.expected_responses is not None:
            try:
                expected = self.expected_responses.pop(0)
            except IndexError:
                raise AssertionError(
                    proto.FailureType.UnexpectedMessage,
                    "Got %s, but no message has been expected" % repr(msg),
                )

            if msg.__class__ != expected.__class__:
                raise AssertionError(
                    proto.FailureType.UnexpectedMessage,
                    "Expected %s, got %s" % (repr(expected), repr(msg)),
                )

            for field, value in expected.__dict__.items():
                # Unset fields on the expected message act as wildcards.
                if value is None or value == []:
                    continue
                if getattr(msg, field) != value:
                    raise AssertionError(
                        proto.FailureType.UnexpectedMessage,
                        "Expected %s, got %s" % (repr(expected), repr(msg)),
                    )

    def callback_ButtonRequest(self, msg):
        """Press the configured button through the debug link and acknowledge."""
        self.DEBUG("ButtonRequest code: " + get_buttonrequest_value(msg.code))

        self.DEBUG("Pressing button " + str(self.button))
        if self.button_wait:
            self.DEBUG("Waiting %d seconds " % self.button_wait)
            time.sleep(self.button_wait)
        self.debug.press_button(self.button)
        return proto.ButtonAck()

    def callback_PinMatrixRequest(self, msg):
        """Answer with the PIN read from the debug link, or a fixed wrong one."""
        if self.pin_correct:
            pin = self.debug.read_pin_encoded()
        else:
            pin = "444222"
        return proto.PinMatrixAck(pin=pin)

    def callback_PassphraseRequest(self, msg):
        """Answer with the passphrase stored by :meth:`set_passphrase`."""
        self.DEBUG("Provided passphrase: '%s'" % self.passphrase)
        return proto.PassphraseAck(passphrase=self.passphrase)

    def callback_PassphraseStateRequest(self, msg):
        """Acknowledge the passphrase state without sending any data back."""
        return proto.PassphraseStateAck()

    def callback_WordRequest(self, msg):
        """Answer a recovery word request using the debug link.

        If the debug link reports a word, that word is sent. Otherwise a
        non-zero position selects the word from the mnemonic stored by
        :meth:`set_mnemonic` (positions are 1-based). Raises ``RuntimeError``
        if neither is available.
        """
        (word, pos) = self.debug.read_recovery_word()
        if word != "":
            return proto.WordAck(word=word)
        if pos != 0:
            return proto.WordAck(word=self.mnemonic[pos - 1])

        raise RuntimeError("Unexpected call")
2014-02-13 15:46:21 +00:00
2017-06-23 19:31:42 +00:00
2014-02-13 15:46:21 +00:00
class ProtocolMixin(object):
    """Device-level protocol helpers plus deprecated redirects.

    Initializes the device session on construction and keeps the returned
    features. Most former client methods are exposed here as ``MovedTo``
    redirects to the functions that replaced them.
    """

    # Vendor strings accepted by init_device().
    VENDORS = ("bitcointrezor.com", "trezor.io")

    def __init__(self, state=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.state = state
        self.init_device()
        self.tx_api = None

    def set_tx_api(self, tx_api):
        """Set the object used by _prepare_sign_tx() to fetch previous transactions."""
        self.tx_api = tx_api

    def init_device(self):
        """Send ``Initialize`` (with the stored state, if any) and store the features.

        Raises ``RuntimeError`` when the reported vendor is not in ``VENDORS``.
        """
        init_msg = proto.Initialize()
        if self.state is not None:
            init_msg.state = self.state

        checked_call = tools.expect(proto.Features)(self.call)
        self.features = checked_call(init_msg)

        vendor = str(self.features.vendor)
        if vendor not in self.VENDORS:
            raise RuntimeError("Unsupported device")

    @staticmethod
    def expand_path(n):
        """Deprecated alias for ``tools.parse_path``; emits a ``DeprecationWarning``."""
        warnings.warn(
            "expand_path is deprecated, use tools.parse_path",
            DeprecationWarning,
            stacklevel=2,
        )
        return tools.parse_path(n)

    @tools.expect(proto.Success, field="message")
    def ping(
        self,
        msg,
        button_protection=False,
        pin_protection=False,
        passphrase_protection=False,
    ):
        """Send a ``Ping`` carrying *msg* and the requested protection flags."""
        ping_request = proto.Ping(
            message=msg,
            button_protection=button_protection,
            pin_protection=pin_protection,
            passphrase_protection=passphrase_protection,
        )
        return self.call(ping_request)

    def get_device_id(self):
        """Return the device id from the features stored by init_device()."""
        return self.features.device_id

    def _prepare_sign_tx(self, inputs, outputs):
        """Build the mapping of transactions needed for signing.

        The transaction being signed is stored under the key ``None``. Each
        distinct ``prev_hash`` of a non-witness input maps to the previous
        transaction fetched from ``self.tx_api``.

        Raises ``RuntimeError`` when a previous transaction is needed but no
        tx_api is set.
        """
        current_tx = proto.TransactionType()
        current_tx.inputs = inputs
        current_tx.outputs = outputs

        txes = {None: current_tx}

        for inp in inputs:
            # Skip hashes already collected, and witness inputs, for which no
            # previous transaction is fetched.
            if inp.prev_hash in txes or inp.script_type in (
                proto.InputScriptType.SPENDP2SHWITNESS,
                proto.InputScriptType.SPENDWITNESS,
            ):
                continue

            if not self.tx_api:
                raise RuntimeError("TX_API not defined")

            txes[inp.prev_hash] = self.tx_api.get_tx(inp.prev_hash.hex())

        return txes

    @tools.expect(proto.Success, field="message")
    def clear_session(self):
        """Send ``ClearSession`` to the device."""
        return self.call(proto.ClearSession())

    # Device functionality
    wipe_device = MovedTo(device.wipe)
    recovery_device = MovedTo(device.recover)
    reset_device = MovedTo(device.reset)
    backup_device = MovedTo(device.backup)

    # debugging
    load_device_by_mnemonic = MovedTo(debuglink.load_device_by_mnemonic)
    load_device_by_xprv = MovedTo(debuglink.load_device_by_xprv)
    self_test = MovedTo(debuglink.self_test)

    set_u2f_counter = MovedTo(device.set_u2f_counter)

    apply_settings = MovedTo(device.apply_settings)
    apply_flags = MovedTo(device.apply_flags)
    change_pin = MovedTo(device.change_pin)

    # Firmware functionality
    firmware_update = MovedTo(firmware.update)

    # BTC-like functionality
    get_public_node = MovedTo(btc.get_public_node)
    get_address = MovedTo(btc.get_address)
    sign_tx = MovedTo(btc.sign_tx)
    sign_message = MovedTo(btc.sign_message)
    verify_message = MovedTo(btc.verify_message)

    # CoSi functionality
    cosi_commit = MovedTo(cosi.commit)
    cosi_sign = MovedTo(cosi.sign)

    # Ethereum functionality
    ethereum_get_address = MovedTo(ethereum.get_address)
    ethereum_sign_tx = MovedTo(ethereum.sign_tx)
    ethereum_sign_message = MovedTo(ethereum.sign_message)
    ethereum_verify_message = MovedTo(ethereum.verify_message)

    # Lisk functionality
    lisk_get_address = MovedTo(lisk.get_address)
    lisk_get_public_key = MovedTo(lisk.get_public_key)
    lisk_sign_message = MovedTo(lisk.sign_message)
    lisk_verify_message = MovedTo(lisk.verify_message)
    lisk_sign_tx = MovedTo(lisk.sign_tx)

    # NEM functionality
    nem_get_address = MovedTo(nem.get_address)
    nem_sign_tx = MovedTo(nem.sign_tx)

    # Stellar functionality
    stellar_get_address = MovedTo(stellar.get_address)
    stellar_sign_transaction = MovedTo(stellar.sign_tx)

    # Miscellaneous cryptographic functionality
    get_entropy = MovedTo(misc.get_entropy)
    sign_identity = MovedTo(misc.sign_identity)
    get_ecdh_session_key = MovedTo(misc.get_ecdh_session_key)
    encrypt_keyvalue = MovedTo(misc.encrypt_keyvalue)
    decrypt_keyvalue = MovedTo(misc.decrypt_keyvalue)
2018-04-04 01:50:22 +00:00
2017-06-23 19:31:42 +00:00
class TrezorClient(ProtocolMixin, TextUIMixin, BaseClient):
    """Client combining the protocol layer, the text UI callbacks and the base transport layer."""

    def __init__(self, transport, *args, **kwargs):
        # The transport is always forwarded by keyword down the mixin chain.
        super().__init__(*args, transport=transport, **kwargs)
2014-02-13 15:46:21 +00:00
2017-06-23 19:31:42 +00:00
class TrezorClientDebugLink(ProtocolMixin, DebugLinkMixin, BaseClient):
    """Client combining the protocol layer, the debug-link auto-responders and the base transport layer."""

    def __init__(self, transport, *args, **kwargs):
        # The transport is always forwarded by keyword down the mixin chain.
        super().__init__(*args, transport=transport, **kwargs)