mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-22 07:28:10 +00:00
trezorlib: get rid of TextUIMixin
This also moves DebugLinkMixin to debuglink.py and converts the mixin to a subclass of TrezorClient (which is finally becoming a reasonable-looking class). This takes advantage of the new UI protocol and is ready for further improvements, namely, queuing input for tests that require swipes. The ui.py module contains a Click-based implementation of the UI protocol. Use of callback_* methods has been limited and will probably be cleaned up further (The contract has changed so we'll try to make third party code fail noisily. It is unclear whether a backwards compatible approach will be possible). Furthermore, device.recovery() now takes a callback as an argument. This way we can get rid of WordRequest callbacks, which are only used in the recovery flow.
This commit is contained in:
parent
6d9157c4a5
commit
06927e003e
16
trezorctl
16
trezorctl
@ -49,6 +49,7 @@ from trezorlib import (
|
||||
stellar,
|
||||
tezos,
|
||||
tools,
|
||||
ui,
|
||||
)
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport import enumerate_devices, get_transport
|
||||
@ -120,7 +121,7 @@ def cli(ctx, path, verbose, is_json):
|
||||
if path is not None:
|
||||
click.echo("Using path: {}".format(path))
|
||||
sys.exit(1)
|
||||
return TrezorClient(transport=device)
|
||||
return TrezorClient(transport=device, ui=ui.ClickUI)
|
||||
|
||||
ctx.obj = get_device
|
||||
|
||||
@ -214,6 +215,7 @@ def get_features(connect):
|
||||
@click.option("-r", "--remove", is_flag=True)
|
||||
@click.pass_obj
|
||||
def change_pin(connect, remove):
|
||||
click.echo(ui.PIN_MATRIX_DESCRIPTION)
|
||||
return device.change_pin(connect(), remove)
|
||||
|
||||
|
||||
@ -326,7 +328,7 @@ def wipe_device(connect, bootloader):
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo(
|
||||
"Wiping user data and firmware! Please confirm the action on your device ..."
|
||||
"Wiping user data and firmware!"
|
||||
)
|
||||
else:
|
||||
if client.features.bootloader_mode:
|
||||
@ -339,7 +341,7 @@ def wipe_device(connect, bootloader):
|
||||
click.echo("Aborting.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo("Wiping user data! Please confirm the action on your device ...")
|
||||
click.echo("Wiping user data!")
|
||||
|
||||
try:
|
||||
return device.wipe(connect())
|
||||
@ -417,6 +419,12 @@ def recovery_device(
|
||||
rec_type,
|
||||
dry_run,
|
||||
):
|
||||
if rec_type == proto.RecoveryDeviceType.ScrambledWords:
|
||||
input_callback = ui.mnemonic_words(expand)
|
||||
else:
|
||||
input_callback = ui.matrix_words
|
||||
click.echo(ui.RECOVERY_MATRIX_DESCRIPTION)
|
||||
|
||||
return device.recover(
|
||||
connect(),
|
||||
int(words),
|
||||
@ -424,8 +432,8 @@ def recovery_device(
|
||||
pin_protection,
|
||||
label,
|
||||
"english",
|
||||
input_callback,
|
||||
rec_type,
|
||||
expand,
|
||||
dry_run,
|
||||
)
|
||||
|
||||
|
@ -30,6 +30,7 @@ from . import (
|
||||
debuglink,
|
||||
device,
|
||||
ethereum,
|
||||
exceptions,
|
||||
firmware,
|
||||
lisk,
|
||||
mapping,
|
||||
@ -47,38 +48,7 @@ if sys.version_info.major < 3:
|
||||
SCREENSHOT = False
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# make a getch function
|
||||
try:
|
||||
import termios
|
||||
import tty
|
||||
|
||||
# POSIX system. Create and return a getch that manipulates the tty.
|
||||
# On Windows, termios will fail to import.
|
||||
|
||||
def getch():
|
||||
fd = sys.stdin.fileno()
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
try:
|
||||
tty.setraw(fd)
|
||||
ch = sys.stdin.read(1)
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
||||
return ch
|
||||
|
||||
|
||||
except ImportError:
|
||||
# Windows system.
|
||||
# Use msvcrt's getch function.
|
||||
import msvcrt
|
||||
|
||||
def getch():
|
||||
while True:
|
||||
key = msvcrt.getch()
|
||||
if key in (0x00, 0xE0):
|
||||
# skip special keys: read the scancode and repeat
|
||||
msvcrt.getch()
|
||||
continue
|
||||
return key.decode()
|
||||
PinException = exceptions.PinException
|
||||
|
||||
|
||||
def get_buttonrequest_value(code):
|
||||
@ -90,10 +60,6 @@ def get_buttonrequest_value(code):
|
||||
][0]
|
||||
|
||||
|
||||
class PinException(tools.CallException):
|
||||
pass
|
||||
|
||||
|
||||
class MovedTo:
|
||||
"""Deprecation redirector for methods that were formerly part of TrezorClient"""
|
||||
|
||||
@ -120,9 +86,10 @@ class MovedTo:
|
||||
class BaseClient(object):
|
||||
# Implements very basic layer of sending raw protobuf
|
||||
# messages to device and getting its response back.
|
||||
def __init__(self, transport, **kwargs):
|
||||
def __init__(self, transport, ui, **kwargs):
|
||||
LOG.info("creating client instance for device: {}".format(transport.get_path()))
|
||||
self.transport = transport
|
||||
self.ui = ui
|
||||
super(BaseClient, self).__init__() # *args, **kwargs)
|
||||
|
||||
def close(self):
|
||||
@ -137,298 +104,60 @@ class BaseClient(object):
|
||||
self.transport.write(msg)
|
||||
return self.transport.read()
|
||||
|
||||
@tools.session
|
||||
def call(self, msg):
|
||||
resp = self.call_raw(msg)
|
||||
handler_name = "callback_%s" % resp.__class__.__name__
|
||||
handler = getattr(self, handler_name, None)
|
||||
def callback_PinMatrixRequest(self, msg):
|
||||
pin = self.ui.get_pin(msg.type)
|
||||
if not pin.isdigit():
|
||||
raise ValueError("Non-numeric PIN provided")
|
||||
|
||||
if handler is not None:
|
||||
msg = handler(resp)
|
||||
if msg is None:
|
||||
raise ValueError(
|
||||
"Callback %s must return protobuf message, not None" % handler
|
||||
)
|
||||
resp = self.call(msg)
|
||||
|
||||
return resp
|
||||
|
||||
def callback_Failure(self, msg):
|
||||
if msg.code in (
|
||||
resp = self.call_raw(proto.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, proto.Failure) and resp.code in (
|
||||
proto.FailureType.PinInvalid,
|
||||
proto.FailureType.PinCancelled,
|
||||
proto.FailureType.PinExpected,
|
||||
):
|
||||
raise PinException(msg.code, msg.message)
|
||||
raise exceptions.PinException(msg.code, msg.message)
|
||||
else:
|
||||
return resp
|
||||
|
||||
raise tools.CallException(msg.code, msg.message)
|
||||
def callback_PassphraseRequest(self, msg):
|
||||
if msg.on_device:
|
||||
passphrase = None
|
||||
else:
|
||||
passphrase = self.ui.get_passphrase()
|
||||
return self.call_raw(proto.PassphraseAck(passphrase=passphrase))
|
||||
|
||||
def callback_PassphraseStateRequest(self, msg):
|
||||
self.state = msg.state
|
||||
return self.call_raw(proto.PassphraseStateAck())
|
||||
|
||||
def callback_ButtonRequest(self, msg):
|
||||
# do this raw - send ButtonAck first, notify UI later
|
||||
self.transport.write(proto.ButtonAck())
|
||||
self.ui.button_request(msg.code)
|
||||
return self.transport.read()
|
||||
|
||||
@tools.session
|
||||
def call(self, msg):
|
||||
resp = self.call_raw(msg)
|
||||
while True:
|
||||
handler_name = "callback_{}".format(resp.__class__.__name__)
|
||||
handler = getattr(self, handler_name, None)
|
||||
if handler is None:
|
||||
break
|
||||
resp = handler(resp) # pylint: disable=E1102
|
||||
|
||||
if isinstance(resp, proto.Failure):
|
||||
if resp.code == proto.FailureType.ActionCancelled:
|
||||
raise exceptions.Cancelled
|
||||
raise exceptions.TrezorException(resp.code, resp.message)
|
||||
|
||||
return resp
|
||||
|
||||
def register_message(self, msg):
|
||||
"""Allow application to register custom protobuf message type"""
|
||||
mapping.register_message(msg)
|
||||
|
||||
|
||||
class TextUIMixin(object):
|
||||
# This class demonstrates easy test-based UI
|
||||
# integration between the device and wallet.
|
||||
# You can implement similar functionality
|
||||
# by implementing your own GuiMixin with
|
||||
# graphical widgets for every type of these callbacks.
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(TextUIMixin, self).__init__(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def print(text):
|
||||
print(text, file=sys.stderr)
|
||||
|
||||
def callback_ButtonRequest(self, msg):
|
||||
# log("Sending ButtonAck for %s " % get_buttonrequest_value(msg.code))
|
||||
return proto.ButtonAck()
|
||||
|
||||
def callback_RecoveryMatrix(self, msg):
|
||||
if self.recovery_matrix_first_pass:
|
||||
self.recovery_matrix_first_pass = False
|
||||
self.print(
|
||||
"Use the numeric keypad to describe positions. For the word list use only left and right keys."
|
||||
)
|
||||
self.print("Use backspace to correct an entry. The keypad layout is:")
|
||||
self.print(" 7 8 9 7 | 9")
|
||||
self.print(" 4 5 6 4 | 6")
|
||||
self.print(" 1 2 3 1 | 3")
|
||||
while True:
|
||||
character = getch()
|
||||
if character in ("\x03", "\x04"):
|
||||
return proto.Cancel()
|
||||
|
||||
if character in ("\x08", "\x7f"):
|
||||
return proto.WordAck(word="\x08")
|
||||
|
||||
# ignore middle column if only 6 keys requested.
|
||||
if msg.type == proto.WordRequestType.Matrix6 and character in (
|
||||
"2",
|
||||
"5",
|
||||
"8",
|
||||
):
|
||||
continue
|
||||
|
||||
if character.isdigit():
|
||||
return proto.WordAck(word=character)
|
||||
|
||||
def callback_PinMatrixRequest(self, msg):
|
||||
if msg.type == proto.PinMatrixRequestType.Current:
|
||||
desc = "current PIN"
|
||||
elif msg.type == proto.PinMatrixRequestType.NewFirst:
|
||||
desc = "new PIN"
|
||||
elif msg.type == proto.PinMatrixRequestType.NewSecond:
|
||||
desc = "new PIN again"
|
||||
else:
|
||||
desc = "PIN"
|
||||
|
||||
self.print(
|
||||
"Use the numeric keypad to describe number positions. The layout is:"
|
||||
)
|
||||
self.print(" 7 8 9")
|
||||
self.print(" 4 5 6")
|
||||
self.print(" 1 2 3")
|
||||
self.print("Please enter %s: " % desc)
|
||||
pin = getpass.getpass("")
|
||||
if not pin.isdigit():
|
||||
raise ValueError("Non-numerical PIN provided")
|
||||
return proto.PinMatrixAck(pin=pin)
|
||||
|
||||
def callback_PassphraseRequest(self, msg):
|
||||
if msg.on_device is True:
|
||||
return proto.PassphraseAck()
|
||||
|
||||
if os.getenv("PASSPHRASE") is not None:
|
||||
self.print("Passphrase required. Using PASSPHRASE environment variable.")
|
||||
passphrase = Mnemonic.normalize_string(os.getenv("PASSPHRASE"))
|
||||
return proto.PassphraseAck(passphrase=passphrase)
|
||||
|
||||
self.print("Passphrase required: ")
|
||||
passphrase = getpass.getpass("")
|
||||
self.print("Confirm your Passphrase: ")
|
||||
if passphrase == getpass.getpass(""):
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
return proto.PassphraseAck(passphrase=passphrase)
|
||||
else:
|
||||
self.print("Passphrase did not match! ")
|
||||
exit()
|
||||
|
||||
def callback_PassphraseStateRequest(self, msg):
|
||||
return proto.PassphraseStateAck()
|
||||
|
||||
def callback_WordRequest(self, msg):
|
||||
if msg.type in (proto.WordRequestType.Matrix9, proto.WordRequestType.Matrix6):
|
||||
return self.callback_RecoveryMatrix(msg)
|
||||
self.print("Enter one word of mnemonic: ")
|
||||
word = input()
|
||||
if self.expand:
|
||||
word = self.mnemonic_wordlist.expand_word(word)
|
||||
return proto.WordAck(word=word)
|
||||
|
||||
|
||||
class DebugLinkMixin(object):
|
||||
# This class implements automatic responses
|
||||
# and other functionality for unit tests
|
||||
# for various callbacks, created in order
|
||||
# to automatically pass unit tests.
|
||||
#
|
||||
# This mixing should be used only for purposes
|
||||
# of unit testing, because it will fail to work
|
||||
# without special DebugLink interface provided
|
||||
# by the device.
|
||||
DEBUG = LOG.getChild("debug_link").debug
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(DebugLinkMixin, self).__init__(*args, **kwargs)
|
||||
self.debug = None
|
||||
self.in_with_statement = 0
|
||||
self.button_wait = 0
|
||||
self.screenshot_id = 0
|
||||
|
||||
# Always press Yes and provide correct pin
|
||||
self.setup_debuglink(True, True)
|
||||
|
||||
# Do not expect any specific response from device
|
||||
self.expected_responses = None
|
||||
|
||||
# Use blank passphrase
|
||||
self.set_passphrase("")
|
||||
|
||||
def close(self):
|
||||
super(DebugLinkMixin, self).close()
|
||||
if self.debug:
|
||||
self.debug.close()
|
||||
|
||||
def set_debuglink(self, debug_transport):
|
||||
self.debug = debuglink.DebugLink(debug_transport)
|
||||
|
||||
def set_buttonwait(self, secs):
|
||||
self.button_wait = secs
|
||||
|
||||
def __enter__(self):
|
||||
# For usage in with/expected_responses
|
||||
self.in_with_statement += 1
|
||||
return self
|
||||
|
||||
def __exit__(self, _type, value, traceback):
|
||||
self.in_with_statement -= 1
|
||||
|
||||
if _type is not None:
|
||||
# Another exception raised
|
||||
return False
|
||||
|
||||
# return isinstance(value, TypeError)
|
||||
# Evaluate missed responses in 'with' statement
|
||||
if self.expected_responses is not None and len(self.expected_responses):
|
||||
raise RuntimeError(
|
||||
"Some of expected responses didn't come from device: %s"
|
||||
% [repr(x) for x in self.expected_responses]
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
self.expected_responses = None
|
||||
return False
|
||||
|
||||
def set_expected_responses(self, expected):
|
||||
if not self.in_with_statement:
|
||||
raise RuntimeError("Must be called inside 'with' statement")
|
||||
self.expected_responses = expected
|
||||
|
||||
def setup_debuglink(self, button, pin_correct):
|
||||
self.button = button # True -> YES button, False -> NO button
|
||||
self.pin_correct = pin_correct
|
||||
|
||||
def set_passphrase(self, passphrase):
|
||||
self.passphrase = Mnemonic.normalize_string(passphrase)
|
||||
|
||||
def set_mnemonic(self, mnemonic):
|
||||
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
||||
|
||||
def call_raw(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
if SCREENSHOT and self.debug:
|
||||
from PIL import Image
|
||||
|
||||
layout = self.debug.read_layout()
|
||||
im = Image.new("RGB", (128, 64))
|
||||
pix = im.load()
|
||||
for x in range(128):
|
||||
for y in range(64):
|
||||
rx, ry = 127 - x, 63 - y
|
||||
if (ord(layout[rx + (ry / 8) * 128]) & (1 << (ry % 8))) > 0:
|
||||
pix[x, y] = (255, 255, 255)
|
||||
im.save("scr%05d.png" % self.screenshot_id)
|
||||
self.screenshot_id += 1
|
||||
|
||||
resp = super(DebugLinkMixin, self).call_raw(msg)
|
||||
self._check_request(resp)
|
||||
return resp
|
||||
|
||||
def _check_request(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
if self.expected_responses is not None:
|
||||
try:
|
||||
expected = self.expected_responses.pop(0)
|
||||
except IndexError:
|
||||
raise AssertionError(
|
||||
proto.FailureType.UnexpectedMessage,
|
||||
"Got %s, but no message has been expected" % repr(msg),
|
||||
)
|
||||
|
||||
if msg.__class__ != expected.__class__:
|
||||
raise AssertionError(
|
||||
proto.FailureType.UnexpectedMessage,
|
||||
"Expected %s, got %s" % (repr(expected), repr(msg)),
|
||||
)
|
||||
|
||||
for field, value in expected.__dict__.items():
|
||||
if value is None or value == []:
|
||||
continue
|
||||
if getattr(msg, field) != value:
|
||||
raise AssertionError(
|
||||
proto.FailureType.UnexpectedMessage,
|
||||
"Expected %s, got %s" % (repr(expected), repr(msg)),
|
||||
)
|
||||
|
||||
def callback_ButtonRequest(self, msg):
|
||||
self.DEBUG("ButtonRequest code: " + get_buttonrequest_value(msg.code))
|
||||
|
||||
self.DEBUG("Pressing button " + str(self.button))
|
||||
if self.button_wait:
|
||||
self.DEBUG("Waiting %d seconds " % self.button_wait)
|
||||
time.sleep(self.button_wait)
|
||||
self.debug.press_button(self.button)
|
||||
return proto.ButtonAck()
|
||||
|
||||
def callback_PinMatrixRequest(self, msg):
|
||||
if self.pin_correct:
|
||||
pin = self.debug.read_pin_encoded()
|
||||
else:
|
||||
pin = "444222"
|
||||
return proto.PinMatrixAck(pin=pin)
|
||||
|
||||
def callback_PassphraseRequest(self, msg):
|
||||
self.DEBUG("Provided passphrase: '%s'" % self.passphrase)
|
||||
return proto.PassphraseAck(passphrase=self.passphrase)
|
||||
|
||||
def callback_PassphraseStateRequest(self, msg):
|
||||
return proto.PassphraseStateAck()
|
||||
|
||||
def callback_WordRequest(self, msg):
|
||||
(word, pos) = self.debug.read_recovery_word()
|
||||
if word != "":
|
||||
return proto.WordAck(word=word)
|
||||
if pos != 0:
|
||||
return proto.WordAck(word=self.mnemonic[pos - 1])
|
||||
|
||||
raise RuntimeError("Unexpected call")
|
||||
|
||||
|
||||
class ProtocolMixin(object):
|
||||
VENDORS = ("bitcointrezor.com", "trezor.io")
|
||||
|
||||
@ -442,10 +171,11 @@ class ProtocolMixin(object):
|
||||
self.tx_api = tx_api
|
||||
|
||||
def init_device(self):
|
||||
init_msg = proto.Initialize()
|
||||
if self.state is not None:
|
||||
init_msg.state = self.state
|
||||
self.features = tools.expect(proto.Features)(self.call)(init_msg)
|
||||
resp = self.call(proto.Initialize(state=self.state))
|
||||
if not isinstance(resp, proto.Features):
|
||||
raise exceptions.TrezorException("Unexpected initial response")
|
||||
else:
|
||||
self.features = resp
|
||||
if str(self.features.vendor) not in self.VENDORS:
|
||||
raise RuntimeError("Unsupported device")
|
||||
|
||||
@ -512,11 +242,6 @@ class ProtocolMixin(object):
|
||||
reset_device = MovedTo(device.reset)
|
||||
backup_device = MovedTo(device.backup)
|
||||
|
||||
# debugging
|
||||
load_device_by_mnemonic = MovedTo(debuglink.load_device_by_mnemonic)
|
||||
load_device_by_xprv = MovedTo(debuglink.load_device_by_xprv)
|
||||
self_test = MovedTo(debuglink.self_test)
|
||||
|
||||
set_u2f_counter = MovedTo(device.set_u2f_counter)
|
||||
|
||||
apply_settings = MovedTo(device.apply_settings)
|
||||
@ -566,7 +291,7 @@ class ProtocolMixin(object):
|
||||
decrypt_keyvalue = MovedTo(misc.decrypt_keyvalue)
|
||||
|
||||
|
||||
class TrezorClient(ProtocolMixin, TextUIMixin, BaseClient):
|
||||
class TrezorClient(ProtocolMixin, BaseClient):
|
||||
def __init__(self, transport, *args, **kwargs):
|
||||
super().__init__(transport=transport, *args, **kwargs)
|
||||
|
||||
|
@ -18,25 +18,15 @@
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
from . import messages as proto, tools
|
||||
from .client import TrezorClient
|
||||
from .tools import expect
|
||||
|
||||
|
||||
def pin_info(pin):
|
||||
print("Device asks for PIN %s" % pin)
|
||||
|
||||
|
||||
def button_press(yes_no):
|
||||
print("User pressed", '"y"' if yes_no else '"n"')
|
||||
|
||||
|
||||
class DebugLink(object):
|
||||
def __init__(self, transport, pin_func=pin_info, button_func=button_press):
|
||||
class DebugLink:
|
||||
def __init__(self, transport):
|
||||
self.transport = transport
|
||||
self.transport.session_begin()
|
||||
|
||||
self.pin_func = pin_func
|
||||
self.button_func = button_func
|
||||
|
||||
def close(self):
|
||||
self.transport.session_end()
|
||||
|
||||
@ -47,30 +37,21 @@ class DebugLink(object):
|
||||
ret = self.transport.read()
|
||||
return ret
|
||||
|
||||
def read_pin(self):
|
||||
obj = self._call(proto.DebugLinkGetState())
|
||||
print("Read PIN:", obj.pin)
|
||||
print("Read matrix:", obj.matrix)
|
||||
def state(self):
|
||||
return self._call(proto.DebugLinkGetState())
|
||||
|
||||
return (obj.pin, obj.matrix)
|
||||
def read_pin(self):
|
||||
state = self.state()
|
||||
return state.pin, state.matrix
|
||||
|
||||
def read_pin_encoded(self):
|
||||
pin, _ = self.read_pin()
|
||||
pin_encoded = self.encode_pin(pin)
|
||||
self.pin_func(pin_encoded)
|
||||
return pin_encoded
|
||||
return self.encode_pin(*self.read_pin())
|
||||
|
||||
def encode_pin(self, pin):
|
||||
_, matrix = self.read_pin()
|
||||
|
||||
# Now we have real PIN and PIN matrix.
|
||||
# We have to encode that into encoded pin,
|
||||
# because application must send back positions
|
||||
# on keypad, not a real PIN.
|
||||
pin_encoded = "".join([str(matrix.index(p) + 1) for p in pin])
|
||||
|
||||
print("Encoded PIN:", pin_encoded)
|
||||
return pin_encoded
|
||||
def encode_pin(self, pin, matrix=None):
|
||||
"""Transform correct PIN according to the displayed matrix."""
|
||||
if matrix is None:
|
||||
_, matrix = self.read_pin()
|
||||
return "".join([str(matrix.index(p) + 1) for p in pin])
|
||||
|
||||
def read_layout(self):
|
||||
obj = self._call(proto.DebugLinkGetState())
|
||||
@ -80,10 +61,6 @@ class DebugLink(object):
|
||||
obj = self._call(proto.DebugLinkGetState())
|
||||
return obj.mnemonic
|
||||
|
||||
def read_node(self):
|
||||
obj = self._call(proto.DebugLinkGetState())
|
||||
return obj.node
|
||||
|
||||
def read_recovery_word(self):
|
||||
obj = self._call(proto.DebugLinkGetState())
|
||||
return (obj.recovery_fake_word, obj.recovery_word_pos)
|
||||
@ -104,36 +81,39 @@ class DebugLink(object):
|
||||
obj = self._call(proto.DebugLinkGetState())
|
||||
return obj.passphrase_protection
|
||||
|
||||
def input(self, word=None, button=None, swipe=None):
|
||||
decision = proto.DebugLinkDecision()
|
||||
if button is not None:
|
||||
decision.yes_no = button
|
||||
elif word is not None:
|
||||
decision.input = word
|
||||
elif swipe is not None:
|
||||
decision.up_down = swipe
|
||||
else:
|
||||
raise ValueError("You need to provide input data.")
|
||||
self._call(decision, nowait=True)
|
||||
|
||||
def press_button(self, yes_no):
|
||||
print("Pressing", yes_no)
|
||||
self.button_func(yes_no)
|
||||
self._call(proto.DebugLinkDecision(yes_no=yes_no), nowait=True)
|
||||
|
||||
def press_yes(self):
|
||||
self.press_button(True)
|
||||
self.input(button=True)
|
||||
|
||||
def press_no(self):
|
||||
self.press_button(False)
|
||||
|
||||
def swipe(self, up_down):
|
||||
print("Swiping", up_down)
|
||||
self._call(proto.DebugLinkDecision(up_down=up_down), nowait=True)
|
||||
self.input(button=False)
|
||||
|
||||
def swipe_up(self):
|
||||
self.swipe(True)
|
||||
self.input(swipe=True)
|
||||
|
||||
def swipe_down(self):
|
||||
self.swipe(False)
|
||||
|
||||
def input(self, text):
|
||||
self._call(proto.DebugLinkDecision(input=text), nowait=True)
|
||||
self.input(swipe=False)
|
||||
|
||||
def stop(self):
|
||||
self._call(proto.DebugLinkStop(), nowait=True)
|
||||
|
||||
@expect(proto.DebugLinkMemory, field="memory")
|
||||
def memory_read(self, address, length):
|
||||
obj = self._call(proto.DebugLinkMemoryRead(address=address, length=length))
|
||||
return obj.memory
|
||||
return self._call(proto.DebugLinkMemoryRead(address=address, length=length))
|
||||
|
||||
def memory_write(self, address, memory, flash=False):
|
||||
self._call(
|
||||
@ -145,6 +125,163 @@ class DebugLink(object):
|
||||
self._call(proto.DebugLinkFlashErase(sector=sector), nowait=True)
|
||||
|
||||
|
||||
class DebugUI:
|
||||
def __init__(self, debuglink: DebugLink):
|
||||
self.debuglink = debuglink
|
||||
self.pin = None
|
||||
self.passphrase = "sphinx of black quartz, judge my wov"
|
||||
|
||||
def button_request(self):
|
||||
self.debuglink.press_yes()
|
||||
|
||||
def get_pin(self):
|
||||
if self.pin:
|
||||
return self.pin
|
||||
else:
|
||||
return self.debuglink.read_pin_encoded()
|
||||
|
||||
def get_passphrase(self):
|
||||
return self.passphrase
|
||||
|
||||
|
||||
class TrezorClientDebugLink(TrezorClient):
|
||||
# This class implements automatic responses
|
||||
# and other functionality for unit tests
|
||||
# for various callbacks, created in order
|
||||
# to automatically pass unit tests.
|
||||
#
|
||||
# This mixing should be used only for purposes
|
||||
# of unit testing, because it will fail to work
|
||||
# without special DebugLink interface provided
|
||||
# by the device.
|
||||
DEBUG = LOG.getChild("debug_link").debug
|
||||
|
||||
def __init__(self, transport):
|
||||
self.debug = DebugLink(transport.find_debug())
|
||||
self.ui = DebugUI(self.debug)
|
||||
super().__init__(transport, self.ui)
|
||||
|
||||
self.in_with_statement = 0
|
||||
self.button_wait = 0
|
||||
self.screenshot_id = 0
|
||||
|
||||
# Always press Yes and provide correct pin
|
||||
self.setup_debuglink(True, True)
|
||||
|
||||
# Do not expect any specific response from device
|
||||
self.expected_responses = None
|
||||
|
||||
# Use blank passphrase
|
||||
self.set_passphrase("")
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
if self.debug:
|
||||
self.debug.close()
|
||||
|
||||
def set_buttonwait(self, secs):
|
||||
self.button_wait = secs
|
||||
|
||||
def __enter__(self):
|
||||
# For usage in with/expected_responses
|
||||
self.in_with_statement += 1
|
||||
return self
|
||||
|
||||
def __exit__(self, _type, value, traceback):
|
||||
self.in_with_statement -= 1
|
||||
|
||||
if _type is not None:
|
||||
# Another exception raised
|
||||
return False
|
||||
|
||||
# return isinstance(value, TypeError)
|
||||
# Evaluate missed responses in 'with' statement
|
||||
if self.expected_responses is not None and len(self.expected_responses):
|
||||
raise RuntimeError(
|
||||
"Some of expected responses didn't come from device: %s"
|
||||
% [repr(x) for x in self.expected_responses]
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
self.expected_responses = None
|
||||
return False
|
||||
|
||||
def set_expected_responses(self, expected):
|
||||
if not self.in_with_statement:
|
||||
raise RuntimeError("Must be called inside 'with' statement")
|
||||
self.expected_responses = expected
|
||||
|
||||
def setup_debuglink(self, button, pin_correct):
|
||||
# self.button = button # True -> YES button, False -> NO button
|
||||
if pin_correct:
|
||||
self.ui.pin = None
|
||||
else:
|
||||
self.ui.pin = "444222"
|
||||
|
||||
def set_passphrase(self, passphrase):
|
||||
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
||||
|
||||
def set_mnemonic(self, mnemonic):
|
||||
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
||||
|
||||
def call_raw(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
# if SCREENSHOT and self.debug:
|
||||
# from PIL import Image
|
||||
|
||||
# layout = self.debug.state().layout
|
||||
# im = Image.new("RGB", (128, 64))
|
||||
# pix = im.load()
|
||||
# for x in range(128):
|
||||
# for y in range(64):
|
||||
# rx, ry = 127 - x, 63 - y
|
||||
# if (ord(layout[rx + (ry / 8) * 128]) & (1 << (ry % 8))) > 0:
|
||||
# pix[x, y] = (255, 255, 255)
|
||||
# im.save("scr%05d.png" % self.screenshot_id)
|
||||
# self.screenshot_id += 1
|
||||
|
||||
resp = super().call_raw(msg)
|
||||
self._check_request(resp)
|
||||
return resp
|
||||
|
||||
def _check_request(self, msg):
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
if self.expected_responses is not None:
|
||||
try:
|
||||
expected = self.expected_responses.pop(0)
|
||||
except IndexError:
|
||||
raise AssertionError(
|
||||
proto.FailureType.UnexpectedMessage,
|
||||
"Got %s, but no message has been expected" % repr(msg),
|
||||
)
|
||||
|
||||
if msg.__class__ != expected.__class__:
|
||||
raise AssertionError(
|
||||
proto.FailureType.UnexpectedMessage,
|
||||
"Expected %s, got %s" % (repr(expected), repr(msg)),
|
||||
)
|
||||
|
||||
for field, value in expected.__dict__.items():
|
||||
if value is None or value == []:
|
||||
continue
|
||||
if getattr(msg, field) != value:
|
||||
raise AssertionError(
|
||||
proto.FailureType.UnexpectedMessage,
|
||||
"Expected %s, got %s" % (repr(expected), repr(msg)),
|
||||
)
|
||||
|
||||
def mnemonic_callback(self, _):
|
||||
word, pos = self.debug.read_recovery_word()
|
||||
if word != "":
|
||||
return word
|
||||
if pos != 0:
|
||||
return self.mnemonic[pos - 1]
|
||||
|
||||
raise RuntimeError("Unexpected call")
|
||||
|
||||
|
||||
@expect(proto.Success, field="message")
|
||||
def load_device_by_mnemonic(
|
||||
client,
|
||||
@ -172,7 +309,7 @@ def load_device_by_mnemonic(
|
||||
|
||||
if client.features.initialized:
|
||||
raise RuntimeError(
|
||||
"Device is initialized already. Call wipe_device() and try again."
|
||||
"Device is initialized already. Call device.wipe() and try again."
|
||||
)
|
||||
|
||||
resp = client.call(
|
||||
|
@ -22,6 +22,9 @@ from mnemonic import Mnemonic
|
||||
from . import messages as proto
|
||||
from .tools import expect, session
|
||||
from .transport import enumerate_devices, get_transport
|
||||
from .exceptions import Cancelled
|
||||
|
||||
RECOVERY_BACK = "\x08" # backspace character, sent literally
|
||||
|
||||
|
||||
class TrezorDevice:
|
||||
@ -106,24 +109,17 @@ def recover(
|
||||
pin_protection,
|
||||
label,
|
||||
language,
|
||||
input_callback,
|
||||
type=proto.RecoveryDeviceType.ScrambledWords,
|
||||
expand=False,
|
||||
dry_run=False,
|
||||
):
|
||||
if client.features.initialized and not dry_run:
|
||||
raise RuntimeError(
|
||||
"Device is initialized already. Call wipe_device() and try again."
|
||||
)
|
||||
|
||||
if word_count not in (12, 18, 24):
|
||||
raise ValueError("Invalid word count. Use 12/18/24")
|
||||
|
||||
client.recovery_matrix_first_pass = True
|
||||
|
||||
client.expand = expand
|
||||
if client.expand:
|
||||
# optimization to load the wordlist once, instead of for each recovery word
|
||||
client.mnemonic_wordlist = Mnemonic("english")
|
||||
if client.features.initialized and not dry_run:
|
||||
raise RuntimeError(
|
||||
"Device already initialized. Call device.wipe() and try again."
|
||||
)
|
||||
|
||||
res = client.call(
|
||||
proto.RecoveryDevice(
|
||||
@ -138,6 +134,13 @@ def recover(
|
||||
)
|
||||
)
|
||||
|
||||
while isinstance(res, proto.WordRequest):
|
||||
try:
|
||||
inp = input_callback(res.type)
|
||||
res = client.call(proto.WordAck(word=inp))
|
||||
except Cancelled:
|
||||
res = client.call(proto.Cancel())
|
||||
|
||||
client.init_device()
|
||||
return res
|
||||
|
||||
|
10
trezorlib/exceptions.py
Normal file
10
trezorlib/exceptions.py
Normal file
@ -0,0 +1,10 @@
|
||||
class TrezorException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PinException(TrezorException):
|
||||
pass
|
||||
|
||||
|
||||
class Cancelled(TrezorException):
|
||||
pass
|
@ -17,7 +17,8 @@
|
||||
import os
|
||||
|
||||
from trezorlib import coins, debuglink, device, tx_api
|
||||
from trezorlib.client import TrezorClientDebugLink
|
||||
from trezorlib.messages.PassphraseSourceType import HOST as PASSPHRASE_ON_HOST
|
||||
from trezorlib.debuglink import TrezorClientDebugLink
|
||||
|
||||
from . import conftest
|
||||
|
||||
@ -40,9 +41,7 @@ class TrezorTest:
|
||||
|
||||
def setup_method(self, method):
|
||||
wirelink = conftest.get_device()
|
||||
debuglink = wirelink.find_debug()
|
||||
self.client = TrezorClientDebugLink(wirelink)
|
||||
self.client.set_debuglink(debuglink)
|
||||
self.client.set_tx_api(coins.tx_api["Bitcoin"])
|
||||
# self.client.set_buttonwait(3)
|
||||
|
||||
@ -64,6 +63,8 @@ class TrezorTest:
|
||||
label="test",
|
||||
language="english",
|
||||
)
|
||||
if passphrase:
|
||||
device.apply_settings(self.client, passphrase_source=PASSPHRASE_ON_HOST)
|
||||
|
||||
def setup_mnemonic_allallall(self):
|
||||
self._setup_mnemonic(mnemonic=TrezorTest.mnemonic_all)
|
||||
|
@ -24,12 +24,12 @@ from .common import TrezorTest
|
||||
@pytest.mark.skip_t2
|
||||
class TestDebuglink(TrezorTest):
|
||||
def test_layout(self):
|
||||
layout = self.client.debug.read_layout()
|
||||
layout = self.client.debug.state().layout
|
||||
assert len(layout) == 1024
|
||||
|
||||
def test_mnemonic(self):
|
||||
self.setup_mnemonic_nopin_nopassphrase()
|
||||
mnemonic = self.client.debug.read_mnemonic()
|
||||
mnemonic = self.client.debug.state().mnemonic
|
||||
assert mnemonic == self.mnemonic12
|
||||
|
||||
def test_pin(self):
|
||||
@ -39,9 +39,9 @@ class TestDebuglink(TrezorTest):
|
||||
resp = self.client.call_raw(proto.Ping(message="test", pin_protection=True))
|
||||
assert isinstance(resp, proto.PinMatrixRequest)
|
||||
|
||||
pin = self.client.debug.read_pin()
|
||||
assert pin[0] == "1234"
|
||||
assert pin[1] != ""
|
||||
pin, matrix = self.client.debug.read_pin()
|
||||
assert pin == "1234"
|
||||
assert matrix != ""
|
||||
|
||||
pin_encoded = self.client.debug.read_pin_encoded()
|
||||
resp = self.client.call_raw(proto.PinMatrixAck(pin=pin_encoded))
|
||||
|
@ -25,15 +25,10 @@ from .common import TrezorTest
|
||||
class TestDeviceLoad(TrezorTest):
|
||||
def test_load_device_1(self):
|
||||
self.setup_mnemonic_nopin_nopassphrase()
|
||||
|
||||
mnemonic = self.client.debug.read_mnemonic()
|
||||
assert mnemonic == self.mnemonic12
|
||||
|
||||
pin = self.client.debug.read_pin()[0]
|
||||
assert pin is None
|
||||
|
||||
passphrase_protection = self.client.debug.read_passphrase_protection()
|
||||
assert passphrase_protection is False
|
||||
state = self.client.debug.state()
|
||||
assert state.mnemonic == self.mnemonic12
|
||||
assert state.pin is None
|
||||
assert state.passphrase_protection is False
|
||||
|
||||
address = btc.get_address(self.client, "Bitcoin", [])
|
||||
assert address == "1EfKbQupktEMXf4gujJ9kCFo83k1iMqwqK"
|
||||
@ -41,15 +36,10 @@ class TestDeviceLoad(TrezorTest):
|
||||
def test_load_device_2(self):
|
||||
self.setup_mnemonic_pin_passphrase()
|
||||
self.client.set_passphrase("passphrase")
|
||||
|
||||
mnemonic = self.client.debug.read_mnemonic()
|
||||
assert mnemonic == self.mnemonic12
|
||||
|
||||
pin = self.client.debug.read_pin()[0]
|
||||
assert pin == self.pin4
|
||||
|
||||
passphrase_protection = self.client.debug.read_passphrase_protection()
|
||||
assert passphrase_protection is True
|
||||
state = self.client.debug.state()
|
||||
assert state.mnemonic == self.mnemonic12
|
||||
assert state.pin == self.pin4
|
||||
assert state.passphrase_protection is True
|
||||
|
||||
address = btc.get_address(self.client, "Bitcoin", [])
|
||||
assert address == "15fiTDFwZd2kauHYYseifGi9daH2wniDHH"
|
||||
|
@ -145,18 +145,35 @@ class TestProtectionLevels(TrezorTest):
|
||||
device.reset(self.client, False, 128, True, False, "label", "english")
|
||||
|
||||
def test_recovery_device(self):
|
||||
self.client.set_mnemonic(self.mnemonic12)
|
||||
with self.client:
|
||||
self.client.set_mnemonic(self.mnemonic12)
|
||||
self.client.set_expected_responses(
|
||||
[proto.ButtonRequest()]
|
||||
+ [proto.WordRequest()] * 24
|
||||
+ [proto.Success(), proto.Features()]
|
||||
)
|
||||
device.recover(self.client, 12, False, False, "label", "english")
|
||||
|
||||
device.recover(
|
||||
self.client,
|
||||
12,
|
||||
False,
|
||||
False,
|
||||
"label",
|
||||
"english",
|
||||
self.client.mnemonic_callback,
|
||||
)
|
||||
|
||||
# This must fail, because device is already initialized
|
||||
with pytest.raises(Exception):
|
||||
device.recover(self.client, 12, False, False, "label", "english")
|
||||
with pytest.raises(RuntimeError):
|
||||
device.recover(
|
||||
self.client,
|
||||
12,
|
||||
False,
|
||||
False,
|
||||
"label",
|
||||
"english",
|
||||
self.client.mnemonic_callback,
|
||||
)
|
||||
|
||||
def test_sign_message(self):
|
||||
with self.client:
|
||||
|
@ -22,6 +22,9 @@ import unicodedata
|
||||
from typing import List, NewType
|
||||
|
||||
from .coins import slip44
|
||||
from .exceptions import TrezorException
|
||||
|
||||
CallException = TrezorException
|
||||
|
||||
HARDENED_FLAG = 1 << 31
|
||||
|
||||
@ -174,10 +177,6 @@ def normalize_nfc(txt):
|
||||
return unicodedata.normalize("NFC", txt).encode()
|
||||
|
||||
|
||||
class CallException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class expect:
|
||||
# Decorator checks if the method
|
||||
# returned one of expected protobuf messages
|
||||
|
141
trezorlib/ui.py
Normal file
141
trezorlib/ui.py
Normal file
@ -0,0 +1,141 @@
|
||||
import os
|
||||
|
||||
import click
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
from .messages import WordRequestType, PinMatrixRequestType
|
||||
from . import device
|
||||
from .exceptions import Cancelled
|
||||
|
||||
PIN_MATRIX_DESCRIPTION = """
|
||||
Use the numeric keypad to describe number positions. The layout is:
|
||||
7 8 9
|
||||
4 5 6
|
||||
1 2 3
|
||||
""".strip()
|
||||
|
||||
RECOVERY_MATRIX_DESCRIPTION = """
|
||||
Use the numeric keypad to describe positions.
|
||||
For the word list use only left and right keys.
|
||||
Use backspace to correct an entry.
|
||||
|
||||
The keypad layout is:
|
||||
7 8 9 7 | 9
|
||||
4 5 6 4 | 6
|
||||
1 2 3 1 | 3
|
||||
""".strip()
|
||||
|
||||
PIN_GENERIC = None
|
||||
PIN_CURRENT = PinMatrixRequestType.Current
|
||||
PIN_NEW = PinMatrixRequestType.NewFirst
|
||||
PIN_CONFIRM = PinMatrixRequestType.NewSecond
|
||||
|
||||
|
||||
class ClickUI:
|
||||
@staticmethod
|
||||
def button_request(code):
|
||||
click.echo("Please confirm action on your Trezor device")
|
||||
|
||||
@staticmethod
|
||||
def get_pin(code=None):
|
||||
if code == PIN_CURRENT:
|
||||
desc = "current PIN"
|
||||
elif code == PIN_NEW:
|
||||
desc = "new PIN"
|
||||
elif code == PIN_CONFIRM:
|
||||
desc = "new PIN again"
|
||||
else:
|
||||
desc = "PIN"
|
||||
click.echo(PIN_MATRIX_DESCRIPTION)
|
||||
|
||||
while True:
|
||||
pin = click.prompt("Please enter {}".format(desc), hide_input=True)
|
||||
if not pin.isdigit():
|
||||
click.echo("Non-numerical PIN provided, please try again")
|
||||
else:
|
||||
return pin
|
||||
|
||||
@staticmethod
|
||||
def get_passphrase():
|
||||
if os.getenv("PASSPHRASE") is not None:
|
||||
click.echo("Passphrase required. Using PASSPHRASE environment variable.")
|
||||
return os.getenv("PASSPHRASE")
|
||||
|
||||
while True:
|
||||
passphrase = click.prompt("Passphrase required", hide_input=True)
|
||||
second = click.prompt("Confirm your passphrase", hide_input=True)
|
||||
if passphrase == second:
|
||||
return passphrase
|
||||
else:
|
||||
click.echo("Passphrase did not match. Please try again.")
|
||||
|
||||
|
||||
def mnemonic_words(expand=False, language="english"):
|
||||
if expand:
|
||||
wordlist = Mnemonic(language).wordlist
|
||||
else:
|
||||
wordlist = set()
|
||||
|
||||
def expand_word(word):
|
||||
if not expand:
|
||||
return word
|
||||
if word in wordlist:
|
||||
return word
|
||||
matches = [w for w in wordlist if w.startswith(word)]
|
||||
if len(matches) == 1:
|
||||
return word
|
||||
click.echo("Choose one of: " + ", ".join(matches))
|
||||
raise KeyError(word)
|
||||
|
||||
def get_word(type):
|
||||
assert type == WordRequestType.Plain
|
||||
while True:
|
||||
try:
|
||||
word = click.prompt("Enter one word of mnemonic")
|
||||
return expand_word(word)
|
||||
except KeyError:
|
||||
pass
|
||||
except (KeyboardInterrupt, click.Abort):
|
||||
raise Cancelled from None
|
||||
|
||||
return get_word
|
||||
|
||||
|
||||
try:
|
||||
# workaround for Click issue https://github.com/pallets/click/pull/1108
|
||||
import msvcrt
|
||||
|
||||
def getchar():
|
||||
while True:
|
||||
key = msvcrt.getwch()
|
||||
if key == "\x03":
|
||||
raise KeyboardInterrupt
|
||||
if key in (0x00, 0xe0):
|
||||
# skip special keys: read the scancode and repeat
|
||||
msvcrt.getwch()
|
||||
continue
|
||||
return key
|
||||
|
||||
|
||||
except ImportError:
|
||||
getchar = click.getchar
|
||||
|
||||
|
||||
def matrix_words(type):
|
||||
while True:
|
||||
try:
|
||||
ch = getchar()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
raise Cancelled from None
|
||||
|
||||
if ch in "\x04\x1b":
|
||||
# Ctrl+D, Esc
|
||||
raise Cancelled
|
||||
if ch in "\x08\x7f":
|
||||
# Backspace, Del
|
||||
return device.RECOVERY_BACK
|
||||
if type == WordRequestType.Matrix6 and ch in "147369":
|
||||
return ch
|
||||
if type == WordRequestType.Matrix9 and ch in "123456789":
|
||||
return ch
|
||||
|
Loading…
Reference in New Issue
Block a user