1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 07:28:10 +00:00

trezorlib: get rid of TextUIMixin

This also moves DebugLinkMixin to debuglink.py and converts the mixin to
a subclass of TrezorClient (which is finally becoming a
reasonable-looking class). This takes advantage of the new UI protocol
and is ready for further improvements, namely, queuing input for tests
that require swipes.

The ui.py module contains a Click-based implementation of the UI
protocol. Use of callback_* methods has been limited and will probably
be cleaned up further (The contract has changed so we'll try to make
third party code fail noisily. It is unclear whether a backwards
compatible approach will be possible).

Furthermore, device.recovery() now takes a callback as an argument. This
way we can get rid of WordRequest callbacks, which are only used in the
recovery flow.
This commit is contained in:
matejcik 2018-09-13 18:47:19 +02:00
parent 6d9157c4a5
commit 06927e003e
11 changed files with 461 additions and 430 deletions

View File

@ -49,6 +49,7 @@ from trezorlib import (
stellar,
tezos,
tools,
ui,
)
from trezorlib.client import TrezorClient
from trezorlib.transport import enumerate_devices, get_transport
@ -120,7 +121,7 @@ def cli(ctx, path, verbose, is_json):
if path is not None:
click.echo("Using path: {}".format(path))
sys.exit(1)
return TrezorClient(transport=device)
return TrezorClient(transport=device, ui=ui.ClickUI)
ctx.obj = get_device
@ -214,6 +215,7 @@ def get_features(connect):
@click.option("-r", "--remove", is_flag=True)
@click.pass_obj
def change_pin(connect, remove):
click.echo(ui.PIN_MATRIX_DESCRIPTION)
return device.change_pin(connect(), remove)
@ -326,7 +328,7 @@ def wipe_device(connect, bootloader):
sys.exit(1)
else:
click.echo(
"Wiping user data and firmware! Please confirm the action on your device ..."
"Wiping user data and firmware!"
)
else:
if client.features.bootloader_mode:
@ -339,7 +341,7 @@ def wipe_device(connect, bootloader):
click.echo("Aborting.")
sys.exit(1)
else:
click.echo("Wiping user data! Please confirm the action on your device ...")
click.echo("Wiping user data!")
try:
return device.wipe(connect())
@ -417,6 +419,12 @@ def recovery_device(
rec_type,
dry_run,
):
if rec_type == proto.RecoveryDeviceType.ScrambledWords:
input_callback = ui.mnemonic_words(expand)
else:
input_callback = ui.matrix_words
click.echo(ui.RECOVERY_MATRIX_DESCRIPTION)
return device.recover(
connect(),
int(words),
@ -424,8 +432,8 @@ def recovery_device(
pin_protection,
label,
"english",
input_callback,
rec_type,
expand,
dry_run,
)

View File

@ -30,6 +30,7 @@ from . import (
debuglink,
device,
ethereum,
exceptions,
firmware,
lisk,
mapping,
@ -47,38 +48,7 @@ if sys.version_info.major < 3:
SCREENSHOT = False
LOG = logging.getLogger(__name__)
# make a getch function
try:
import termios
import tty
# POSIX system. Create and return a getch that manipulates the tty.
# On Windows, termios will fail to import.
def getch():
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(fd)
ch = sys.stdin.read(1)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
return ch
except ImportError:
# Windows system.
# Use msvcrt's getch function.
import msvcrt
def getch():
while True:
key = msvcrt.getch()
if key in (0x00, 0xE0):
# skip special keys: read the scancode and repeat
msvcrt.getch()
continue
return key.decode()
PinException = exceptions.PinException
def get_buttonrequest_value(code):
@ -90,10 +60,6 @@ def get_buttonrequest_value(code):
][0]
class PinException(tools.CallException):
pass
class MovedTo:
"""Deprecation redirector for methods that were formerly part of TrezorClient"""
@ -120,9 +86,10 @@ class MovedTo:
class BaseClient(object):
# Implements very basic layer of sending raw protobuf
# messages to device and getting its response back.
def __init__(self, transport, **kwargs):
def __init__(self, transport, ui, **kwargs):
LOG.info("creating client instance for device: {}".format(transport.get_path()))
self.transport = transport
self.ui = ui
super(BaseClient, self).__init__() # *args, **kwargs)
def close(self):
@ -137,298 +104,60 @@ class BaseClient(object):
self.transport.write(msg)
return self.transport.read()
@tools.session
def call(self, msg):
resp = self.call_raw(msg)
handler_name = "callback_%s" % resp.__class__.__name__
handler = getattr(self, handler_name, None)
def callback_PinMatrixRequest(self, msg):
pin = self.ui.get_pin(msg.type)
if not pin.isdigit():
raise ValueError("Non-numeric PIN provided")
if handler is not None:
msg = handler(resp)
if msg is None:
raise ValueError(
"Callback %s must return protobuf message, not None" % handler
)
resp = self.call(msg)
return resp
def callback_Failure(self, msg):
if msg.code in (
resp = self.call_raw(proto.PinMatrixAck(pin=pin))
if isinstance(resp, proto.Failure) and resp.code in (
proto.FailureType.PinInvalid,
proto.FailureType.PinCancelled,
proto.FailureType.PinExpected,
):
raise PinException(msg.code, msg.message)
raise exceptions.PinException(msg.code, msg.message)
else:
return resp
raise tools.CallException(msg.code, msg.message)
def callback_PassphraseRequest(self, msg):
if msg.on_device:
passphrase = None
else:
passphrase = self.ui.get_passphrase()
return self.call_raw(proto.PassphraseAck(passphrase=passphrase))
def callback_PassphraseStateRequest(self, msg):
self.state = msg.state
return self.call_raw(proto.PassphraseStateAck())
def callback_ButtonRequest(self, msg):
# do this raw - send ButtonAck first, notify UI later
self.transport.write(proto.ButtonAck())
self.ui.button_request(msg.code)
return self.transport.read()
@tools.session
def call(self, msg):
resp = self.call_raw(msg)
while True:
handler_name = "callback_{}".format(resp.__class__.__name__)
handler = getattr(self, handler_name, None)
if handler is None:
break
resp = handler(resp) # pylint: disable=E1102
if isinstance(resp, proto.Failure):
if resp.code == proto.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorException(resp.code, resp.message)
return resp
def register_message(self, msg):
"""Allow application to register custom protobuf message type"""
mapping.register_message(msg)
class TextUIMixin(object):
# This class demonstrates easy test-based UI
# integration between the device and wallet.
# You can implement similar functionality
# by implementing your own GuiMixin with
# graphical widgets for every type of these callbacks.
def __init__(self, *args, **kwargs):
super(TextUIMixin, self).__init__(*args, **kwargs)
@staticmethod
def print(text):
print(text, file=sys.stderr)
def callback_ButtonRequest(self, msg):
# log("Sending ButtonAck for %s " % get_buttonrequest_value(msg.code))
return proto.ButtonAck()
def callback_RecoveryMatrix(self, msg):
if self.recovery_matrix_first_pass:
self.recovery_matrix_first_pass = False
self.print(
"Use the numeric keypad to describe positions. For the word list use only left and right keys."
)
self.print("Use backspace to correct an entry. The keypad layout is:")
self.print(" 7 8 9 7 | 9")
self.print(" 4 5 6 4 | 6")
self.print(" 1 2 3 1 | 3")
while True:
character = getch()
if character in ("\x03", "\x04"):
return proto.Cancel()
if character in ("\x08", "\x7f"):
return proto.WordAck(word="\x08")
# ignore middle column if only 6 keys requested.
if msg.type == proto.WordRequestType.Matrix6 and character in (
"2",
"5",
"8",
):
continue
if character.isdigit():
return proto.WordAck(word=character)
def callback_PinMatrixRequest(self, msg):
if msg.type == proto.PinMatrixRequestType.Current:
desc = "current PIN"
elif msg.type == proto.PinMatrixRequestType.NewFirst:
desc = "new PIN"
elif msg.type == proto.PinMatrixRequestType.NewSecond:
desc = "new PIN again"
else:
desc = "PIN"
self.print(
"Use the numeric keypad to describe number positions. The layout is:"
)
self.print(" 7 8 9")
self.print(" 4 5 6")
self.print(" 1 2 3")
self.print("Please enter %s: " % desc)
pin = getpass.getpass("")
if not pin.isdigit():
raise ValueError("Non-numerical PIN provided")
return proto.PinMatrixAck(pin=pin)
def callback_PassphraseRequest(self, msg):
if msg.on_device is True:
return proto.PassphraseAck()
if os.getenv("PASSPHRASE") is not None:
self.print("Passphrase required. Using PASSPHRASE environment variable.")
passphrase = Mnemonic.normalize_string(os.getenv("PASSPHRASE"))
return proto.PassphraseAck(passphrase=passphrase)
self.print("Passphrase required: ")
passphrase = getpass.getpass("")
self.print("Confirm your Passphrase: ")
if passphrase == getpass.getpass(""):
passphrase = Mnemonic.normalize_string(passphrase)
return proto.PassphraseAck(passphrase=passphrase)
else:
self.print("Passphrase did not match! ")
exit()
def callback_PassphraseStateRequest(self, msg):
return proto.PassphraseStateAck()
def callback_WordRequest(self, msg):
if msg.type in (proto.WordRequestType.Matrix9, proto.WordRequestType.Matrix6):
return self.callback_RecoveryMatrix(msg)
self.print("Enter one word of mnemonic: ")
word = input()
if self.expand:
word = self.mnemonic_wordlist.expand_word(word)
return proto.WordAck(word=word)
class DebugLinkMixin(object):
# This class implements automatic responses
# and other functionality for unit tests
# for various callbacks, created in order
# to automatically pass unit tests.
#
# This mixing should be used only for purposes
# of unit testing, because it will fail to work
# without special DebugLink interface provided
# by the device.
DEBUG = LOG.getChild("debug_link").debug
def __init__(self, *args, **kwargs):
super(DebugLinkMixin, self).__init__(*args, **kwargs)
self.debug = None
self.in_with_statement = 0
self.button_wait = 0
self.screenshot_id = 0
# Always press Yes and provide correct pin
self.setup_debuglink(True, True)
# Do not expect any specific response from device
self.expected_responses = None
# Use blank passphrase
self.set_passphrase("")
def close(self):
super(DebugLinkMixin, self).close()
if self.debug:
self.debug.close()
def set_debuglink(self, debug_transport):
self.debug = debuglink.DebugLink(debug_transport)
def set_buttonwait(self, secs):
self.button_wait = secs
def __enter__(self):
# For usage in with/expected_responses
self.in_with_statement += 1
return self
def __exit__(self, _type, value, traceback):
self.in_with_statement -= 1
if _type is not None:
# Another exception raised
return False
# return isinstance(value, TypeError)
# Evaluate missed responses in 'with' statement
if self.expected_responses is not None and len(self.expected_responses):
raise RuntimeError(
"Some of expected responses didn't come from device: %s"
% [repr(x) for x in self.expected_responses]
)
# Cleanup
self.expected_responses = None
return False
def set_expected_responses(self, expected):
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.expected_responses = expected
def setup_debuglink(self, button, pin_correct):
self.button = button # True -> YES button, False -> NO button
self.pin_correct = pin_correct
def set_passphrase(self, passphrase):
self.passphrase = Mnemonic.normalize_string(passphrase)
def set_mnemonic(self, mnemonic):
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def call_raw(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if SCREENSHOT and self.debug:
from PIL import Image
layout = self.debug.read_layout()
im = Image.new("RGB", (128, 64))
pix = im.load()
for x in range(128):
for y in range(64):
rx, ry = 127 - x, 63 - y
if (ord(layout[rx + (ry / 8) * 128]) & (1 << (ry % 8))) > 0:
pix[x, y] = (255, 255, 255)
im.save("scr%05d.png" % self.screenshot_id)
self.screenshot_id += 1
resp = super(DebugLinkMixin, self).call_raw(msg)
self._check_request(resp)
return resp
def _check_request(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if self.expected_responses is not None:
try:
expected = self.expected_responses.pop(0)
except IndexError:
raise AssertionError(
proto.FailureType.UnexpectedMessage,
"Got %s, but no message has been expected" % repr(msg),
)
if msg.__class__ != expected.__class__:
raise AssertionError(
proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)),
)
for field, value in expected.__dict__.items():
if value is None or value == []:
continue
if getattr(msg, field) != value:
raise AssertionError(
proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)),
)
def callback_ButtonRequest(self, msg):
self.DEBUG("ButtonRequest code: " + get_buttonrequest_value(msg.code))
self.DEBUG("Pressing button " + str(self.button))
if self.button_wait:
self.DEBUG("Waiting %d seconds " % self.button_wait)
time.sleep(self.button_wait)
self.debug.press_button(self.button)
return proto.ButtonAck()
def callback_PinMatrixRequest(self, msg):
if self.pin_correct:
pin = self.debug.read_pin_encoded()
else:
pin = "444222"
return proto.PinMatrixAck(pin=pin)
def callback_PassphraseRequest(self, msg):
self.DEBUG("Provided passphrase: '%s'" % self.passphrase)
return proto.PassphraseAck(passphrase=self.passphrase)
def callback_PassphraseStateRequest(self, msg):
return proto.PassphraseStateAck()
def callback_WordRequest(self, msg):
(word, pos) = self.debug.read_recovery_word()
if word != "":
return proto.WordAck(word=word)
if pos != 0:
return proto.WordAck(word=self.mnemonic[pos - 1])
raise RuntimeError("Unexpected call")
class ProtocolMixin(object):
VENDORS = ("bitcointrezor.com", "trezor.io")
@ -442,10 +171,11 @@ class ProtocolMixin(object):
self.tx_api = tx_api
def init_device(self):
init_msg = proto.Initialize()
if self.state is not None:
init_msg.state = self.state
self.features = tools.expect(proto.Features)(self.call)(init_msg)
resp = self.call(proto.Initialize(state=self.state))
if not isinstance(resp, proto.Features):
raise exceptions.TrezorException("Unexpected initial response")
else:
self.features = resp
if str(self.features.vendor) not in self.VENDORS:
raise RuntimeError("Unsupported device")
@ -512,11 +242,6 @@ class ProtocolMixin(object):
reset_device = MovedTo(device.reset)
backup_device = MovedTo(device.backup)
# debugging
load_device_by_mnemonic = MovedTo(debuglink.load_device_by_mnemonic)
load_device_by_xprv = MovedTo(debuglink.load_device_by_xprv)
self_test = MovedTo(debuglink.self_test)
set_u2f_counter = MovedTo(device.set_u2f_counter)
apply_settings = MovedTo(device.apply_settings)
@ -566,7 +291,7 @@ class ProtocolMixin(object):
decrypt_keyvalue = MovedTo(misc.decrypt_keyvalue)
class TrezorClient(ProtocolMixin, TextUIMixin, BaseClient):
class TrezorClient(ProtocolMixin, BaseClient):
def __init__(self, transport, *args, **kwargs):
super().__init__(transport=transport, *args, **kwargs)

View File

@ -18,25 +18,15 @@
from mnemonic import Mnemonic
from . import messages as proto, tools
from .client import TrezorClient
from .tools import expect
def pin_info(pin):
print("Device asks for PIN %s" % pin)
def button_press(yes_no):
print("User pressed", '"y"' if yes_no else '"n"')
class DebugLink(object):
def __init__(self, transport, pin_func=pin_info, button_func=button_press):
class DebugLink:
def __init__(self, transport):
self.transport = transport
self.transport.session_begin()
self.pin_func = pin_func
self.button_func = button_func
def close(self):
self.transport.session_end()
@ -47,30 +37,21 @@ class DebugLink(object):
ret = self.transport.read()
return ret
def read_pin(self):
obj = self._call(proto.DebugLinkGetState())
print("Read PIN:", obj.pin)
print("Read matrix:", obj.matrix)
def state(self):
return self._call(proto.DebugLinkGetState())
return (obj.pin, obj.matrix)
def read_pin(self):
state = self.state()
return state.pin, state.matrix
def read_pin_encoded(self):
pin, _ = self.read_pin()
pin_encoded = self.encode_pin(pin)
self.pin_func(pin_encoded)
return pin_encoded
return self.encode_pin(*self.read_pin())
def encode_pin(self, pin):
_, matrix = self.read_pin()
# Now we have real PIN and PIN matrix.
# We have to encode that into encoded pin,
# because application must send back positions
# on keypad, not a real PIN.
pin_encoded = "".join([str(matrix.index(p) + 1) for p in pin])
print("Encoded PIN:", pin_encoded)
return pin_encoded
def encode_pin(self, pin, matrix=None):
"""Transform correct PIN according to the displayed matrix."""
if matrix is None:
_, matrix = self.read_pin()
return "".join([str(matrix.index(p) + 1) for p in pin])
def read_layout(self):
obj = self._call(proto.DebugLinkGetState())
@ -80,10 +61,6 @@ class DebugLink(object):
obj = self._call(proto.DebugLinkGetState())
return obj.mnemonic
def read_node(self):
obj = self._call(proto.DebugLinkGetState())
return obj.node
def read_recovery_word(self):
obj = self._call(proto.DebugLinkGetState())
return (obj.recovery_fake_word, obj.recovery_word_pos)
@ -104,36 +81,39 @@ class DebugLink(object):
obj = self._call(proto.DebugLinkGetState())
return obj.passphrase_protection
def input(self, word=None, button=None, swipe=None):
decision = proto.DebugLinkDecision()
if button is not None:
decision.yes_no = button
elif word is not None:
decision.input = word
elif swipe is not None:
decision.up_down = swipe
else:
raise ValueError("You need to provide input data.")
self._call(decision, nowait=True)
def press_button(self, yes_no):
print("Pressing", yes_no)
self.button_func(yes_no)
self._call(proto.DebugLinkDecision(yes_no=yes_no), nowait=True)
def press_yes(self):
self.press_button(True)
self.input(button=True)
def press_no(self):
self.press_button(False)
def swipe(self, up_down):
print("Swiping", up_down)
self._call(proto.DebugLinkDecision(up_down=up_down), nowait=True)
self.input(button=False)
def swipe_up(self):
self.swipe(True)
self.input(swipe=True)
def swipe_down(self):
self.swipe(False)
def input(self, text):
self._call(proto.DebugLinkDecision(input=text), nowait=True)
self.input(swipe=False)
def stop(self):
self._call(proto.DebugLinkStop(), nowait=True)
@expect(proto.DebugLinkMemory, field="memory")
def memory_read(self, address, length):
obj = self._call(proto.DebugLinkMemoryRead(address=address, length=length))
return obj.memory
return self._call(proto.DebugLinkMemoryRead(address=address, length=length))
def memory_write(self, address, memory, flash=False):
self._call(
@ -145,6 +125,163 @@ class DebugLink(object):
self._call(proto.DebugLinkFlashErase(sector=sector), nowait=True)
class DebugUI:
def __init__(self, debuglink: DebugLink):
self.debuglink = debuglink
self.pin = None
self.passphrase = "sphinx of black quartz, judge my wov"
def button_request(self):
self.debuglink.press_yes()
def get_pin(self):
if self.pin:
return self.pin
else:
return self.debuglink.read_pin_encoded()
def get_passphrase(self):
return self.passphrase
class TrezorClientDebugLink(TrezorClient):
# This class implements automatic responses
# and other functionality for unit tests
# for various callbacks, created in order
# to automatically pass unit tests.
#
# This mixing should be used only for purposes
# of unit testing, because it will fail to work
# without special DebugLink interface provided
# by the device.
DEBUG = LOG.getChild("debug_link").debug
def __init__(self, transport):
self.debug = DebugLink(transport.find_debug())
self.ui = DebugUI(self.debug)
super().__init__(transport, self.ui)
self.in_with_statement = 0
self.button_wait = 0
self.screenshot_id = 0
# Always press Yes and provide correct pin
self.setup_debuglink(True, True)
# Do not expect any specific response from device
self.expected_responses = None
# Use blank passphrase
self.set_passphrase("")
def close(self):
super().close()
if self.debug:
self.debug.close()
def set_buttonwait(self, secs):
self.button_wait = secs
def __enter__(self):
# For usage in with/expected_responses
self.in_with_statement += 1
return self
def __exit__(self, _type, value, traceback):
self.in_with_statement -= 1
if _type is not None:
# Another exception raised
return False
# return isinstance(value, TypeError)
# Evaluate missed responses in 'with' statement
if self.expected_responses is not None and len(self.expected_responses):
raise RuntimeError(
"Some of expected responses didn't come from device: %s"
% [repr(x) for x in self.expected_responses]
)
# Cleanup
self.expected_responses = None
return False
def set_expected_responses(self, expected):
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.expected_responses = expected
def setup_debuglink(self, button, pin_correct):
# self.button = button # True -> YES button, False -> NO button
if pin_correct:
self.ui.pin = None
else:
self.ui.pin = "444222"
def set_passphrase(self, passphrase):
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def set_mnemonic(self, mnemonic):
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def call_raw(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# if SCREENSHOT and self.debug:
# from PIL import Image
# layout = self.debug.state().layout
# im = Image.new("RGB", (128, 64))
# pix = im.load()
# for x in range(128):
# for y in range(64):
# rx, ry = 127 - x, 63 - y
# if (ord(layout[rx + (ry / 8) * 128]) & (1 << (ry % 8))) > 0:
# pix[x, y] = (255, 255, 255)
# im.save("scr%05d.png" % self.screenshot_id)
# self.screenshot_id += 1
resp = super().call_raw(msg)
self._check_request(resp)
return resp
def _check_request(self, msg):
__tracebackhide__ = True # for pytest # pylint: disable=W0612
if self.expected_responses is not None:
try:
expected = self.expected_responses.pop(0)
except IndexError:
raise AssertionError(
proto.FailureType.UnexpectedMessage,
"Got %s, but no message has been expected" % repr(msg),
)
if msg.__class__ != expected.__class__:
raise AssertionError(
proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)),
)
for field, value in expected.__dict__.items():
if value is None or value == []:
continue
if getattr(msg, field) != value:
raise AssertionError(
proto.FailureType.UnexpectedMessage,
"Expected %s, got %s" % (repr(expected), repr(msg)),
)
def mnemonic_callback(self, _):
word, pos = self.debug.read_recovery_word()
if word != "":
return word
if pos != 0:
return self.mnemonic[pos - 1]
raise RuntimeError("Unexpected call")
@expect(proto.Success, field="message")
def load_device_by_mnemonic(
client,
@ -172,7 +309,7 @@ def load_device_by_mnemonic(
if client.features.initialized:
raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again."
"Device is initialized already. Call device.wipe() and try again."
)
resp = client.call(

View File

@ -22,6 +22,9 @@ from mnemonic import Mnemonic
from . import messages as proto
from .tools import expect, session
from .transport import enumerate_devices, get_transport
from .exceptions import Cancelled
RECOVERY_BACK = "\x08" # backspace character, sent literally
class TrezorDevice:
@ -106,24 +109,17 @@ def recover(
pin_protection,
label,
language,
input_callback,
type=proto.RecoveryDeviceType.ScrambledWords,
expand=False,
dry_run=False,
):
if client.features.initialized and not dry_run:
raise RuntimeError(
"Device is initialized already. Call wipe_device() and try again."
)
if word_count not in (12, 18, 24):
raise ValueError("Invalid word count. Use 12/18/24")
client.recovery_matrix_first_pass = True
client.expand = expand
if client.expand:
# optimization to load the wordlist once, instead of for each recovery word
client.mnemonic_wordlist = Mnemonic("english")
if client.features.initialized and not dry_run:
raise RuntimeError(
"Device already initialized. Call device.wipe() and try again."
)
res = client.call(
proto.RecoveryDevice(
@ -138,6 +134,13 @@ def recover(
)
)
while isinstance(res, proto.WordRequest):
try:
inp = input_callback(res.type)
res = client.call(proto.WordAck(word=inp))
except Cancelled:
res = client.call(proto.Cancel())
client.init_device()
return res

10
trezorlib/exceptions.py Normal file
View File

@ -0,0 +1,10 @@
class TrezorException(Exception):
pass
class PinException(TrezorException):
pass
class Cancelled(TrezorException):
pass

View File

@ -17,7 +17,8 @@
import os
from trezorlib import coins, debuglink, device, tx_api
from trezorlib.client import TrezorClientDebugLink
from trezorlib.messages.PassphraseSourceType import HOST as PASSPHRASE_ON_HOST
from trezorlib.debuglink import TrezorClientDebugLink
from . import conftest
@ -40,9 +41,7 @@ class TrezorTest:
def setup_method(self, method):
wirelink = conftest.get_device()
debuglink = wirelink.find_debug()
self.client = TrezorClientDebugLink(wirelink)
self.client.set_debuglink(debuglink)
self.client.set_tx_api(coins.tx_api["Bitcoin"])
# self.client.set_buttonwait(3)
@ -64,6 +63,8 @@ class TrezorTest:
label="test",
language="english",
)
if passphrase:
device.apply_settings(self.client, passphrase_source=PASSPHRASE_ON_HOST)
def setup_mnemonic_allallall(self):
self._setup_mnemonic(mnemonic=TrezorTest.mnemonic_all)

View File

@ -24,12 +24,12 @@ from .common import TrezorTest
@pytest.mark.skip_t2
class TestDebuglink(TrezorTest):
def test_layout(self):
layout = self.client.debug.read_layout()
layout = self.client.debug.state().layout
assert len(layout) == 1024
def test_mnemonic(self):
self.setup_mnemonic_nopin_nopassphrase()
mnemonic = self.client.debug.read_mnemonic()
mnemonic = self.client.debug.state().mnemonic
assert mnemonic == self.mnemonic12
def test_pin(self):
@ -39,9 +39,9 @@ class TestDebuglink(TrezorTest):
resp = self.client.call_raw(proto.Ping(message="test", pin_protection=True))
assert isinstance(resp, proto.PinMatrixRequest)
pin = self.client.debug.read_pin()
assert pin[0] == "1234"
assert pin[1] != ""
pin, matrix = self.client.debug.read_pin()
assert pin == "1234"
assert matrix != ""
pin_encoded = self.client.debug.read_pin_encoded()
resp = self.client.call_raw(proto.PinMatrixAck(pin=pin_encoded))

View File

@ -25,15 +25,10 @@ from .common import TrezorTest
class TestDeviceLoad(TrezorTest):
def test_load_device_1(self):
self.setup_mnemonic_nopin_nopassphrase()
mnemonic = self.client.debug.read_mnemonic()
assert mnemonic == self.mnemonic12
pin = self.client.debug.read_pin()[0]
assert pin is None
passphrase_protection = self.client.debug.read_passphrase_protection()
assert passphrase_protection is False
state = self.client.debug.state()
assert state.mnemonic == self.mnemonic12
assert state.pin is None
assert state.passphrase_protection is False
address = btc.get_address(self.client, "Bitcoin", [])
assert address == "1EfKbQupktEMXf4gujJ9kCFo83k1iMqwqK"
@ -41,15 +36,10 @@ class TestDeviceLoad(TrezorTest):
def test_load_device_2(self):
self.setup_mnemonic_pin_passphrase()
self.client.set_passphrase("passphrase")
mnemonic = self.client.debug.read_mnemonic()
assert mnemonic == self.mnemonic12
pin = self.client.debug.read_pin()[0]
assert pin == self.pin4
passphrase_protection = self.client.debug.read_passphrase_protection()
assert passphrase_protection is True
state = self.client.debug.state()
assert state.mnemonic == self.mnemonic12
assert state.pin == self.pin4
assert state.passphrase_protection is True
address = btc.get_address(self.client, "Bitcoin", [])
assert address == "15fiTDFwZd2kauHYYseifGi9daH2wniDHH"

View File

@ -145,18 +145,35 @@ class TestProtectionLevels(TrezorTest):
device.reset(self.client, False, 128, True, False, "label", "english")
def test_recovery_device(self):
self.client.set_mnemonic(self.mnemonic12)
with self.client:
self.client.set_mnemonic(self.mnemonic12)
self.client.set_expected_responses(
[proto.ButtonRequest()]
+ [proto.WordRequest()] * 24
+ [proto.Success(), proto.Features()]
)
device.recover(self.client, 12, False, False, "label", "english")
device.recover(
self.client,
12,
False,
False,
"label",
"english",
self.client.mnemonic_callback,
)
# This must fail, because device is already initialized
with pytest.raises(Exception):
device.recover(self.client, 12, False, False, "label", "english")
with pytest.raises(RuntimeError):
device.recover(
self.client,
12,
False,
False,
"label",
"english",
self.client.mnemonic_callback,
)
def test_sign_message(self):
with self.client:

View File

@ -22,6 +22,9 @@ import unicodedata
from typing import List, NewType
from .coins import slip44
from .exceptions import TrezorException
CallException = TrezorException
HARDENED_FLAG = 1 << 31
@ -174,10 +177,6 @@ def normalize_nfc(txt):
return unicodedata.normalize("NFC", txt).encode()
class CallException(Exception):
pass
class expect:
# Decorator checks if the method
# returned one of expected protobuf messages

141
trezorlib/ui.py Normal file
View File

@ -0,0 +1,141 @@
import os
import click
from mnemonic import Mnemonic
from .messages import WordRequestType, PinMatrixRequestType
from . import device
from .exceptions import Cancelled
PIN_MATRIX_DESCRIPTION = """
Use the numeric keypad to describe number positions. The layout is:
7 8 9
4 5 6
1 2 3
""".strip()
RECOVERY_MATRIX_DESCRIPTION = """
Use the numeric keypad to describe positions.
For the word list use only left and right keys.
Use backspace to correct an entry.
The keypad layout is:
7 8 9 7 | 9
4 5 6 4 | 6
1 2 3 1 | 3
""".strip()
PIN_GENERIC = None
PIN_CURRENT = PinMatrixRequestType.Current
PIN_NEW = PinMatrixRequestType.NewFirst
PIN_CONFIRM = PinMatrixRequestType.NewSecond
class ClickUI:
@staticmethod
def button_request(code):
click.echo("Please confirm action on your Trezor device")
@staticmethod
def get_pin(code=None):
if code == PIN_CURRENT:
desc = "current PIN"
elif code == PIN_NEW:
desc = "new PIN"
elif code == PIN_CONFIRM:
desc = "new PIN again"
else:
desc = "PIN"
click.echo(PIN_MATRIX_DESCRIPTION)
while True:
pin = click.prompt("Please enter {}".format(desc), hide_input=True)
if not pin.isdigit():
click.echo("Non-numerical PIN provided, please try again")
else:
return pin
@staticmethod
def get_passphrase():
if os.getenv("PASSPHRASE") is not None:
click.echo("Passphrase required. Using PASSPHRASE environment variable.")
return os.getenv("PASSPHRASE")
while True:
passphrase = click.prompt("Passphrase required", hide_input=True)
second = click.prompt("Confirm your passphrase", hide_input=True)
if passphrase == second:
return passphrase
else:
click.echo("Passphrase did not match. Please try again.")
def mnemonic_words(expand=False, language="english"):
if expand:
wordlist = Mnemonic(language).wordlist
else:
wordlist = set()
def expand_word(word):
if not expand:
return word
if word in wordlist:
return word
matches = [w for w in wordlist if w.startswith(word)]
if len(matches) == 1:
return word
click.echo("Choose one of: " + ", ".join(matches))
raise KeyError(word)
def get_word(type):
assert type == WordRequestType.Plain
while True:
try:
word = click.prompt("Enter one word of mnemonic")
return expand_word(word)
except KeyError:
pass
except (KeyboardInterrupt, click.Abort):
raise Cancelled from None
return get_word
try:
# workaround for Click issue https://github.com/pallets/click/pull/1108
import msvcrt
def getchar():
while True:
key = msvcrt.getwch()
if key == "\x03":
raise KeyboardInterrupt
if key in (0x00, 0xe0):
# skip special keys: read the scancode and repeat
msvcrt.getwch()
continue
return key
except ImportError:
getchar = click.getchar
def matrix_words(type):
while True:
try:
ch = getchar()
except (KeyboardInterrupt, EOFError):
raise Cancelled from None
if ch in "\x04\x1b":
# Ctrl+D, Esc
raise Cancelled
if ch in "\x08\x7f":
# Backspace, Del
return device.RECOVERY_BACK
if type == WordRequestType.Matrix6 and ch in "147369":
return ch
if type == WordRequestType.Matrix9 and ch in "123456789":
return ch