1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-13 17:00:59 +00:00

DebugLink click tests (#632)

This commit is contained in:
matejcik 2019-10-23 11:03:52 +02:00 committed by GitHub
commit 09c3fd1981
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 676 additions and 52 deletions

View File

@ -10,7 +10,7 @@ import "messages-common.proto";
/** /**
* Request: "Press" the button on the device * Request: "Press" the button on the device
* @start * @start
* @next Success * @next DebugLinkLayout
*/ */
message DebugLinkDecision { message DebugLinkDecision {
optional bool yes_no = 1; // true for "Confirm", false for "Cancel" optional bool yes_no = 1; // true for "Confirm", false for "Cancel"
@ -25,6 +25,18 @@ message DebugLinkDecision {
LEFT = 2; LEFT = 2;
RIGHT = 3; RIGHT = 3;
} }
optional uint32 x = 4; // touch X coordinate
optional uint32 y = 5; // touch Y coordinate
optional bool wait = 6; // wait for layout change
}
/**
* Response: Device text layout
* @end
*/
message DebugLinkLayout {
repeated string lines = 1;
} }
/** /**
@ -35,6 +47,7 @@ message DebugLinkDecision {
message DebugLinkGetState { message DebugLinkGetState {
optional bool wait_word_list = 1; // Trezor T only - wait until mnemonic words are shown optional bool wait_word_list = 1; // Trezor T only - wait until mnemonic words are shown
optional bool wait_word_pos = 2; // Trezor T only - wait until reset word position is requested optional bool wait_word_pos = 2; // Trezor T only - wait until reset word position is requested
optional bool wait_layout = 3; // wait until current layout changes
} }
/** /**
@ -54,6 +67,7 @@ message DebugLinkState {
optional uint32 recovery_word_pos = 10; // index of mnemonic word the device is expecting during RecoveryDevice workflow optional uint32 recovery_word_pos = 10; // index of mnemonic word the device is expecting during RecoveryDevice workflow
optional uint32 reset_word_pos = 11; // index of mnemonic word the device is expecting during ResetDevice workflow optional uint32 reset_word_pos = 11; // index of mnemonic word the device is expecting during ResetDevice workflow
optional uint32 mnemonic_type = 12; // current mnemonic type (BIP-39/SLIP-39) optional uint32 mnemonic_type = 12; // current mnemonic type (BIP-39/SLIP-39)
repeated string layout_lines = 13; // current layout text
} }
/** /**

View File

@ -103,6 +103,7 @@ enum MessageType {
MessageType_DebugLinkMemory = 111 [(wire_debug_out) = true]; MessageType_DebugLinkMemory = 111 [(wire_debug_out) = true];
MessageType_DebugLinkMemoryWrite = 112 [(wire_debug_in) = true]; MessageType_DebugLinkMemoryWrite = 112 [(wire_debug_in) = true];
MessageType_DebugLinkFlashErase = 113 [(wire_debug_in) = true]; MessageType_DebugLinkFlashErase = 113 [(wire_debug_in) = true];
MessageType_DebugLinkLayout = 9001 [(wire_debug_out) = true];
// Ethereum // Ethereum
MessageType_EthereumGetPublicKey = 450 [(wire_in) = true]; MessageType_EthereumGetPublicKey = 450 [(wire_in) = true];

View File

@ -4,12 +4,13 @@ if not __debug__:
halt("debug mode inactive") halt("debug mode inactive")
if __debug__: if __debug__:
from trezor import config, log, loop, utils from trezor import config, io, log, loop, ui, utils
from trezor.messages import MessageType, DebugSwipeDirection from trezor.messages import MessageType, DebugSwipeDirection
from trezor.messages.DebugLinkLayout import DebugLinkLayout
from trezor.wire import register from trezor.wire import register
if False: if False:
from typing import Optional from typing import List, Optional
from trezor import wire from trezor import wire
from trezor.messages.DebugLinkDecision import DebugLinkDecision from trezor.messages.DebugLinkDecision import DebugLinkDecision
from trezor.messages.DebugLinkGetState import DebugLinkGetState from trezor.messages.DebugLinkGetState import DebugLinkGetState
@ -28,6 +29,15 @@ if __debug__:
debuglink_decision_chan = loop.chan() debuglink_decision_chan = loop.chan()
layout_change_chan = loop.chan()
current_content = None # type: Optional[List[str]]
def notify_layout_change(layout: ui.Layout) -> None:
global current_content
current_content = layout.read_content()
if layout_change_chan.takers:
layout_change_chan.publish(current_content)
async def debuglink_decision_dispatcher() -> None: async def debuglink_decision_dispatcher() -> None:
from trezor.ui import confirm, swipe from trezor.ui import confirm, swipe
@ -51,14 +61,26 @@ if __debug__:
loop.schedule(debuglink_decision_dispatcher()) loop.schedule(debuglink_decision_dispatcher())
async def return_layout_change(ctx: wire.Context) -> None:
content = await layout_change_chan.take()
await ctx.write(DebugLinkLayout(lines=content))
async def dispatch_DebugLinkDecision( async def dispatch_DebugLinkDecision(
ctx: wire.Context, msg: DebugLinkDecision ctx: wire.Context, msg: DebugLinkDecision
) -> None: ) -> None:
if debuglink_decision_chan.putters: if debuglink_decision_chan.putters:
log.warning(__name__, "DebugLinkDecision queue is not empty") log.warning(__name__, "DebugLinkDecision queue is not empty")
debuglink_decision_chan.publish(msg) if msg.x is not None:
evt_down = io.TOUCH_START, msg.x, msg.y
evt_up = io.TOUCH_END, msg.x, msg.y
loop.synthetic_events.append((io.TOUCH, evt_down))
loop.synthetic_events.append((io.TOUCH, evt_up))
else:
debuglink_decision_chan.publish(msg)
if msg.wait:
loop.schedule(return_layout_change(ctx))
async def dispatch_DebugLinkGetState( async def dispatch_DebugLinkGetState(
ctx: wire.Context, msg: DebugLinkGetState ctx: wire.Context, msg: DebugLinkGetState
@ -73,6 +95,11 @@ if __debug__:
m.passphrase_protection = has_passphrase() m.passphrase_protection = has_passphrase()
m.reset_entropy = reset_internal_entropy m.reset_entropy = reset_internal_entropy
if msg.wait_layout or current_content is None:
m.layout_lines = await layout_change_chan.take()
else:
m.layout_lines = current_content
if msg.wait_word_pos: if msg.wait_word_pos:
m.reset_word_pos = await reset_word_index.take() m.reset_word_pos = await reset_word_index.take()
if msg.wait_word_list: if msg.wait_word_list:

View File

@ -310,6 +310,11 @@ class RecoveryHomescreen(ui.Component):
self.repaint = False self.repaint = False
if __debug__:
def read_content(self) -> List[str]:
return [self.__class__.__name__, self.text, self.subtext or ""]
async def homescreen_dialog( async def homescreen_dialog(
ctx: wire.GenericContext, ctx: wire.GenericContext,

View File

@ -643,6 +643,11 @@ class MnemonicWordSelect(ui.Layout):
return fn return fn
if __debug__:
def read_content(self):
return self.text.read_content() + [b.text for b in self.buttons]
async def show_reset_device_warning(ctx, backup_type: BackupType = BackupType.Bip39): async def show_reset_device_warning(ctx, backup_type: BackupType = BackupType.Bip39):
text = Text("Create new wallet", ui.ICON_RESET, new_lines=False) text = Text("Create new wallet", ui.ICON_RESET, new_lines=False)

View File

@ -21,6 +21,8 @@ def _boot_recovery() -> None:
# boot applications # boot applications
apps.homescreen.boot(features_only=True) apps.homescreen.boot(features_only=True)
if __debug__:
apps.debug.boot()
from apps.management.recovery_device.homescreen import recovery_homescreen from apps.management.recovery_device.homescreen import recovery_homescreen

View File

@ -50,6 +50,9 @@ if __debug__:
log_delay_rb_len = const(10) log_delay_rb_len = const(10)
log_delay_rb = array.array("i", [0] * log_delay_rb_len) log_delay_rb = array.array("i", [0] * log_delay_rb_len)
# synthetic event queue
synthetic_events = [] # type: List[Tuple[int, Any]]
def schedule( def schedule(
task: Task, value: Any = None, deadline: int = None, finalizer: Finalizer = None task: Task, value: Any = None, deadline: int = None, finalizer: Finalizer = None
@ -125,6 +128,15 @@ def run() -> None:
log_delay_rb[log_delay_pos] = delay log_delay_rb[log_delay_pos] = delay
log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len
# process synthetic events
if synthetic_events:
iface, event = synthetic_events[0]
msg_tasks = _paused.pop(iface, ())
if msg_tasks:
synthetic_events.pop(0)
for task in msg_tasks:
_step(task, event)
if io.poll(_paused, msg_entry, delay): if io.poll(_paused, msg_entry, delay):
# message received, run tasks paused on the interface # message received, run tasks paused on the interface
msg_tasks = _paused.pop(msg_entry[0], ()) msg_tasks = _paused.pop(msg_entry[0], ())

View File

@ -19,10 +19,16 @@ class DebugLinkDecision(p.MessageType):
yes_no: bool = None, yes_no: bool = None,
swipe: EnumTypeDebugSwipeDirection = None, swipe: EnumTypeDebugSwipeDirection = None,
input: str = None, input: str = None,
x: int = None,
y: int = None,
wait: bool = None,
) -> None: ) -> None:
self.yes_no = yes_no self.yes_no = yes_no
self.swipe = swipe self.swipe = swipe
self.input = input self.input = input
self.x = x
self.y = y
self.wait = wait
@classmethod @classmethod
def get_fields(cls) -> Dict: def get_fields(cls) -> Dict:
@ -30,4 +36,7 @@ class DebugLinkDecision(p.MessageType):
1: ('yes_no', p.BoolType, 0), 1: ('yes_no', p.BoolType, 0),
2: ('swipe', p.EnumType("DebugSwipeDirection", (0, 1, 2, 3)), 0), 2: ('swipe', p.EnumType("DebugSwipeDirection", (0, 1, 2, 3)), 0),
3: ('input', p.UnicodeType, 0), 3: ('input', p.UnicodeType, 0),
4: ('x', p.UVarintType, 0),
5: ('y', p.UVarintType, 0),
6: ('wait', p.BoolType, 0),
} }

View File

@ -17,13 +17,16 @@ class DebugLinkGetState(p.MessageType):
self, self,
wait_word_list: bool = None, wait_word_list: bool = None,
wait_word_pos: bool = None, wait_word_pos: bool = None,
wait_layout: bool = None,
) -> None: ) -> None:
self.wait_word_list = wait_word_list self.wait_word_list = wait_word_list
self.wait_word_pos = wait_word_pos self.wait_word_pos = wait_word_pos
self.wait_layout = wait_layout
@classmethod @classmethod
def get_fields(cls) -> Dict: def get_fields(cls) -> Dict:
return { return {
1: ('wait_word_list', p.BoolType, 0), 1: ('wait_word_list', p.BoolType, 0),
2: ('wait_word_pos', p.BoolType, 0), 2: ('wait_word_pos', p.BoolType, 0),
3: ('wait_layout', p.BoolType, 0),
} }

View File

@ -0,0 +1,26 @@
# Automatically generated by pb2py
# fmt: off
import protobuf as p
if __debug__:
try:
from typing import Dict, List # noqa: F401
from typing_extensions import Literal # noqa: F401
except ImportError:
pass
class DebugLinkLayout(p.MessageType):
MESSAGE_WIRE_TYPE = 9001
def __init__(
self,
lines: List[str] = None,
) -> None:
self.lines = lines if lines is not None else []
@classmethod
def get_fields(cls) -> Dict:
return {
1: ('lines', p.UnicodeType, p.FLAG_REPEATED),
}

View File

@ -29,6 +29,7 @@ class DebugLinkState(p.MessageType):
recovery_word_pos: int = None, recovery_word_pos: int = None,
reset_word_pos: int = None, reset_word_pos: int = None,
mnemonic_type: int = None, mnemonic_type: int = None,
layout_lines: List[str] = None,
) -> None: ) -> None:
self.layout = layout self.layout = layout
self.pin = pin self.pin = pin
@ -42,6 +43,7 @@ class DebugLinkState(p.MessageType):
self.recovery_word_pos = recovery_word_pos self.recovery_word_pos = recovery_word_pos
self.reset_word_pos = reset_word_pos self.reset_word_pos = reset_word_pos
self.mnemonic_type = mnemonic_type self.mnemonic_type = mnemonic_type
self.layout_lines = layout_lines if layout_lines is not None else []
@classmethod @classmethod
def get_fields(cls) -> Dict: def get_fields(cls) -> Dict:
@ -58,4 +60,5 @@ class DebugLinkState(p.MessageType):
10: ('recovery_word_pos', p.UVarintType, 0), 10: ('recovery_word_pos', p.UVarintType, 0),
11: ('reset_word_pos', p.UVarintType, 0), 11: ('reset_word_pos', p.UVarintType, 0),
12: ('mnemonic_type', p.UVarintType, 0), 12: ('mnemonic_type', p.UVarintType, 0),
13: ('layout_lines', p.UnicodeType, p.FLAG_REPEATED),
} }

View File

@ -70,6 +70,7 @@ DebugLinkMemoryRead = 110 # type: Literal[110]
DebugLinkMemory = 111 # type: Literal[111] DebugLinkMemory = 111 # type: Literal[111]
DebugLinkMemoryWrite = 112 # type: Literal[112] DebugLinkMemoryWrite = 112 # type: Literal[112]
DebugLinkFlashErase = 113 # type: Literal[113] DebugLinkFlashErase = 113 # type: Literal[113]
DebugLinkLayout = 9001 # type: Literal[9001]
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
EthereumGetPublicKey = 450 # type: Literal[450] EthereumGetPublicKey = 450 # type: Literal[450]
EthereumPublicKey = 451 # type: Literal[451] EthereumPublicKey = 451 # type: Literal[451]

View File

@ -5,8 +5,11 @@ from trezorui import Display
from trezor import io, loop, res, utils from trezor import io, loop, res, utils
if __debug__:
from apps.debug import notify_layout_change
if False: if False:
from typing import Any, Awaitable, Generator, Tuple, TypeVar from typing import Any, Awaitable, Generator, List, Tuple, TypeVar
Pos = Tuple[int, int] Pos = Tuple[int, int]
Area = Tuple[int, int, int, int] Area = Tuple[int, int, int, int]
@ -226,6 +229,11 @@ class Component:
def on_touch_end(self, x: int, y: int) -> None: def on_touch_end(self, x: int, y: int) -> None:
pass pass
if __debug__:
def read_content(self) -> List[str]:
return [self.__class__.__name__]
class Result(Exception): class Result(Exception):
""" """
@ -279,6 +287,8 @@ class Layout(Component):
# layout channel. This allows other layouts to cancel us, and the # layout channel. This allows other layouts to cancel us, and the
# layout tasks to trigger restart by exiting (new tasks are created # layout tasks to trigger restart by exiting (new tasks are created
# and we continue, because we are in a loop). # and we continue, because we are in a loop).
if __debug__:
notify_layout_change(self)
while True: while True:
await loop.race(layout_chan.take(), *self.create_tasks()) await loop.race(layout_chan.take(), *self.create_tasks())
except Result as result: except Result as result:

View File

@ -4,7 +4,7 @@ from trezor import ui
from trezor.ui import display, in_area from trezor.ui import display, in_area
if False: if False:
from typing import Type, Union, Optional from typing import List, Optional, Type, Union
class ButtonDefault: class ButtonDefault:
@ -239,3 +239,8 @@ class Button(ui.Component):
def on_click(self) -> None: def on_click(self) -> None:
pass pass
if __debug__:
def read_content(self) -> List[str]:
return ["<Button: {}>".format(self.text)]

View File

@ -8,7 +8,7 @@ if __debug__:
from apps.debug import swipe_signal from apps.debug import swipe_signal
if False: if False:
from typing import Any, Optional, Tuple from typing import Any, Optional, List, Tuple
from trezor.ui.button import ButtonContent, ButtonStyleType from trezor.ui.button import ButtonContent, ButtonStyleType
from trezor.ui.loader import LoaderStyleType from trezor.ui.loader import LoaderStyleType
@ -74,6 +74,11 @@ class Confirm(ui.Layout):
def on_cancel(self) -> None: def on_cancel(self) -> None:
raise ui.Result(CANCELLED) raise ui.Result(CANCELLED)
if __debug__:
def read_content(self) -> List[str]:
return self.content.read_content()
class Pageable: class Pageable:
def __init__(self) -> None: def __init__(self) -> None:
@ -201,6 +206,11 @@ class InfoConfirm(ui.Layout):
def on_info(self) -> None: def on_info(self) -> None:
raise ui.Result(INFO) raise ui.Result(INFO)
if __debug__:
def read_content(self) -> List[str]:
return self.content.read_content()
class HoldToConfirm(ui.Layout): class HoldToConfirm(ui.Layout):
DEFAULT_CONFIRM = "Hold To Confirm" DEFAULT_CONFIRM = "Hold To Confirm"
@ -250,3 +260,8 @@ class HoldToConfirm(ui.Layout):
def on_confirm(self) -> None: def on_confirm(self) -> None:
raise ui.Result(CONFIRMED) raise ui.Result(CONFIRMED)
if __debug__:
def read_content(self) -> List[str]:
return self.content.read_content()

View File

@ -1,5 +1,8 @@
from trezor import ui from trezor import ui
if False:
from typing import List
class Container(ui.Component): class Container(ui.Component):
def __init__(self, *children: ui.Component): def __init__(self, *children: ui.Component):
@ -8,3 +11,8 @@ class Container(ui.Component):
def dispatch(self, event: int, x: int, y: int) -> None: def dispatch(self, event: int, x: int, y: int) -> None:
for child in self.children: for child in self.children:
child.dispatch(event, x, y) child.dispatch(event, x, y)
if __debug__:
def read_content(self) -> List[str]:
return sum((c.read_content() for c in self.children), [])

View File

@ -6,10 +6,10 @@ from trezor.ui.confirm import CANCELLED, CONFIRMED
from trezor.ui.swipe import SWIPE_DOWN, SWIPE_UP, SWIPE_VERTICAL, Swipe from trezor.ui.swipe import SWIPE_DOWN, SWIPE_UP, SWIPE_VERTICAL, Swipe
if __debug__: if __debug__:
from apps.debug import swipe_signal from apps.debug import swipe_signal, notify_layout_change
if False: if False:
from typing import Tuple, List from typing import List, Tuple
def render_scrollbar(pages: int, page: int) -> None: def render_scrollbar(pages: int, page: int) -> None:
@ -89,6 +89,9 @@ class Paginated(ui.Layout):
self.pages[self.page].dispatch(ui.REPAINT, 0, 0) self.pages[self.page].dispatch(ui.REPAINT, 0, 0)
self.repaint = True self.repaint = True
if __debug__:
notify_layout_change(self)
self.on_change() self.on_change()
def create_tasks(self) -> Tuple[loop.Task, ...]: def create_tasks(self) -> Tuple[loop.Task, ...]:
@ -98,6 +101,11 @@ class Paginated(ui.Layout):
if self.one_by_one: if self.one_by_one:
raise ui.Result(self.page) raise ui.Result(self.page)
if __debug__:
def read_content(self) -> List[str]:
return self.pages[self.page].read_content()
class PageWithButtons(ui.Component): class PageWithButtons(ui.Component):
def __init__( def __init__(
@ -154,6 +162,11 @@ class PageWithButtons(ui.Component):
else: else:
self.paginated.on_down() self.paginated.on_down()
if __debug__:
def read_content(self) -> List[str]:
return self.content.read_content()
class PaginatedWithButtons(ui.Layout): class PaginatedWithButtons(ui.Layout):
def __init__( def __init__(
@ -191,3 +204,8 @@ class PaginatedWithButtons(ui.Layout):
def on_change(self) -> None: def on_change(self) -> None:
if self.one_by_one: if self.one_by_one:
raise ui.Result(self.page) raise ui.Result(self.page)
if __debug__:
def read_content(self) -> List[str]:
return self.pages[self.page].read_content()

View File

@ -171,6 +171,12 @@ class Text(ui.Component):
render_text(self.content, self.new_lines, self.max_lines) render_text(self.content, self.new_lines, self.max_lines)
self.repaint = False self.repaint = False
if __debug__:
def read_content(self) -> List[str]:
lines = [w for w in self.content if isinstance(w, str)]
return [self.header_text] + lines[: self.max_lines]
LABEL_LEFT = const(0) LABEL_LEFT = const(0)
LABEL_CENTER = const(1) LABEL_CENTER = const(1)
@ -209,6 +215,11 @@ class Label(ui.Component):
) )
self.repaint = False self.repaint = False
if __debug__:
def read_content(self) -> List[str]:
return [self.content]
def text_center_trim_left( def text_center_trim_left(
x: int, y: int, text: str, font: int = ui.NORMAL, width: int = ui.WIDTH - 16 x: int, y: int, text: str, font: int = ui.NORMAL, width: int = ui.WIDTH - 16

View File

@ -40,7 +40,7 @@ from trezor import log, loop, messages, ui, utils, workflow
from trezor.messages import FailureType from trezor.messages import FailureType
from trezor.messages.Failure import Failure from trezor.messages.Failure import Failure
from trezor.wire import codec_v1 from trezor.wire import codec_v1
from trezor.wire.errors import Error from trezor.wire.errors import ActionCancelled, Error
# Import all errors into namespace, so that `wire.Error` is available from # Import all errors into namespace, so that `wire.Error` is available from
# other packages. # other packages.
@ -364,7 +364,10 @@ async def handle_session(iface: WireInterface, session_id: int) -> None:
# - the first workflow message was not a valid protobuf # - the first workflow message was not a valid protobuf
# - workflow raised some kind of an exception while running # - workflow raised some kind of an exception while running
if __debug__: if __debug__:
log.exception(__name__, exc) if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: {}".format(exc.message))
else:
log.exception(__name__, exc)
res_msg = failure(exc) res_msg = failure(exc)
finally: finally:

View File

@ -13,3 +13,9 @@ DebugLinkLog.text max_size:256
DebugLinkMemory.memory max_size:1024 DebugLinkMemory.memory max_size:1024
DebugLinkMemoryWrite.memory max_size:1024 DebugLinkMemoryWrite.memory max_size:1024
# unused fields
DebugLinkState.layout_lines max_count:0
DebugLinkState.layout_lines max_size:1
DebugLinkLayout.lines max_size:1
DebugLinkLayout.lines max_count:0

View File

@ -132,9 +132,9 @@ class TrezorClient:
self.session_counter += 1 self.session_counter += 1
def close(self): def close(self):
if self.session_counter == 1: self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.transport.end_session() self.transport.end_session()
self.session_counter -= 1
def cancel(self): def cancel(self):
self._raw_write(messages.Cancel()) self._raw_write(messages.Cancel())

View File

@ -14,6 +14,7 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from collections import namedtuple
from copy import deepcopy from copy import deepcopy
from mnemonic import Mnemonic from mnemonic import Mnemonic
@ -25,6 +26,13 @@ from .tools import expect
EXPECTED_RESPONSES_CONTEXT_LINES = 3 EXPECTED_RESPONSES_CONTEXT_LINES = 3
LayoutLines = namedtuple("LayoutLines", "lines text")
def layout_lines(lines):
return LayoutLines(lines, " ".join(lines))
class DebugLink: class DebugLink:
def __init__(self, transport, auto_interact=True): def __init__(self, transport, auto_interact=True):
self.transport = transport self.transport = transport
@ -46,6 +54,13 @@ class DebugLink:
def state(self): def state(self):
return self._call(proto.DebugLinkGetState()) return self._call(proto.DebugLinkGetState())
def read_layout(self):
return layout_lines(self.state().layout_lines)
def wait_layout(self):
obj = self._call(proto.DebugLinkGetState(wait_layout=True))
return layout_lines(obj.layout_lines)
def read_pin(self): def read_pin(self):
state = self.state() state = self.state()
return state.pin, state.matrix return state.pin, state.matrix
@ -83,16 +98,24 @@ class DebugLink:
obj = self._call(proto.DebugLinkGetState()) obj = self._call(proto.DebugLinkGetState())
return obj.passphrase_protection return obj.passphrase_protection
def input(self, word=None, button=None, swipe=None): def input(self, word=None, button=None, swipe=None, x=None, y=None, wait=False):
if not self.allow_interactions: if not self.allow_interactions:
return return
args = sum(a is not None for a in (word, button, swipe)) args = sum(a is not None for a in (word, button, swipe, x))
if args != 1: if args != 1:
raise ValueError("Invalid input - must use one of word, button, swipe") raise ValueError("Invalid input - must use one of word, button, swipe")
decision = proto.DebugLinkDecision(yes_no=button, swipe=swipe, input=word) decision = proto.DebugLinkDecision(
self._call(decision, nowait=True) yes_no=button, swipe=swipe, input=word, x=x, y=y, wait=wait
)
ret = self._call(decision, nowait=not wait)
if ret is not None:
return layout_lines(ret.lines)
def click(self, click, wait=False):
x, y = click
return self.input(x=x, y=y, wait=wait)
def press_yes(self): def press_yes(self):
self.input(button=True) self.input(button=True)

View File

@ -19,10 +19,16 @@ class DebugLinkDecision(p.MessageType):
yes_no: bool = None, yes_no: bool = None,
swipe: EnumTypeDebugSwipeDirection = None, swipe: EnumTypeDebugSwipeDirection = None,
input: str = None, input: str = None,
x: int = None,
y: int = None,
wait: bool = None,
) -> None: ) -> None:
self.yes_no = yes_no self.yes_no = yes_no
self.swipe = swipe self.swipe = swipe
self.input = input self.input = input
self.x = x
self.y = y
self.wait = wait
@classmethod @classmethod
def get_fields(cls) -> Dict: def get_fields(cls) -> Dict:
@ -30,4 +36,7 @@ class DebugLinkDecision(p.MessageType):
1: ('yes_no', p.BoolType, 0), 1: ('yes_no', p.BoolType, 0),
2: ('swipe', p.EnumType("DebugSwipeDirection", (0, 1, 2, 3)), 0), 2: ('swipe', p.EnumType("DebugSwipeDirection", (0, 1, 2, 3)), 0),
3: ('input', p.UnicodeType, 0), 3: ('input', p.UnicodeType, 0),
4: ('x', p.UVarintType, 0),
5: ('y', p.UVarintType, 0),
6: ('wait', p.BoolType, 0),
} }

View File

@ -17,13 +17,16 @@ class DebugLinkGetState(p.MessageType):
self, self,
wait_word_list: bool = None, wait_word_list: bool = None,
wait_word_pos: bool = None, wait_word_pos: bool = None,
wait_layout: bool = None,
) -> None: ) -> None:
self.wait_word_list = wait_word_list self.wait_word_list = wait_word_list
self.wait_word_pos = wait_word_pos self.wait_word_pos = wait_word_pos
self.wait_layout = wait_layout
@classmethod @classmethod
def get_fields(cls) -> Dict: def get_fields(cls) -> Dict:
return { return {
1: ('wait_word_list', p.BoolType, 0), 1: ('wait_word_list', p.BoolType, 0),
2: ('wait_word_pos', p.BoolType, 0), 2: ('wait_word_pos', p.BoolType, 0),
3: ('wait_layout', p.BoolType, 0),
} }

View File

@ -0,0 +1,26 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
if __debug__:
try:
from typing import Dict, List # noqa: F401
from typing_extensions import Literal # noqa: F401
except ImportError:
pass
class DebugLinkLayout(p.MessageType):
MESSAGE_WIRE_TYPE = 9001
def __init__(
self,
lines: List[str] = None,
) -> None:
self.lines = lines if lines is not None else []
@classmethod
def get_fields(cls) -> Dict:
return {
1: ('lines', p.UnicodeType, p.FLAG_REPEATED),
}

View File

@ -29,6 +29,7 @@ class DebugLinkState(p.MessageType):
recovery_word_pos: int = None, recovery_word_pos: int = None,
reset_word_pos: int = None, reset_word_pos: int = None,
mnemonic_type: int = None, mnemonic_type: int = None,
layout_lines: List[str] = None,
) -> None: ) -> None:
self.layout = layout self.layout = layout
self.pin = pin self.pin = pin
@ -42,6 +43,7 @@ class DebugLinkState(p.MessageType):
self.recovery_word_pos = recovery_word_pos self.recovery_word_pos = recovery_word_pos
self.reset_word_pos = reset_word_pos self.reset_word_pos = reset_word_pos
self.mnemonic_type = mnemonic_type self.mnemonic_type = mnemonic_type
self.layout_lines = layout_lines if layout_lines is not None else []
@classmethod @classmethod
def get_fields(cls) -> Dict: def get_fields(cls) -> Dict:
@ -58,4 +60,5 @@ class DebugLinkState(p.MessageType):
10: ('recovery_word_pos', p.UVarintType, 0), 10: ('recovery_word_pos', p.UVarintType, 0),
11: ('reset_word_pos', p.UVarintType, 0), 11: ('reset_word_pos', p.UVarintType, 0),
12: ('mnemonic_type', p.UVarintType, 0), 12: ('mnemonic_type', p.UVarintType, 0),
13: ('layout_lines', p.UnicodeType, p.FLAG_REPEATED),
} }

View File

@ -68,6 +68,7 @@ DebugLinkMemoryRead = 110 # type: Literal[110]
DebugLinkMemory = 111 # type: Literal[111] DebugLinkMemory = 111 # type: Literal[111]
DebugLinkMemoryWrite = 112 # type: Literal[112] DebugLinkMemoryWrite = 112 # type: Literal[112]
DebugLinkFlashErase = 113 # type: Literal[113] DebugLinkFlashErase = 113 # type: Literal[113]
DebugLinkLayout = 9001 # type: Literal[9001]
EthereumGetPublicKey = 450 # type: Literal[450] EthereumGetPublicKey = 450 # type: Literal[450]
EthereumPublicKey = 451 # type: Literal[451] EthereumPublicKey = 451 # type: Literal[451]
EthereumGetAddress = 56 # type: Literal[56] EthereumGetAddress = 56 # type: Literal[56]

View File

@ -41,6 +41,7 @@ from .CosiSignature import CosiSignature
from .DebugLinkDecision import DebugLinkDecision from .DebugLinkDecision import DebugLinkDecision
from .DebugLinkFlashErase import DebugLinkFlashErase from .DebugLinkFlashErase import DebugLinkFlashErase
from .DebugLinkGetState import DebugLinkGetState from .DebugLinkGetState import DebugLinkGetState
from .DebugLinkLayout import DebugLinkLayout
from .DebugLinkLog import DebugLinkLog from .DebugLinkLog import DebugLinkLog
from .DebugLinkMemory import DebugLinkMemory from .DebugLinkMemory import DebugLinkMemory
from .DebugLinkMemoryRead import DebugLinkMemoryRead from .DebugLinkMemoryRead import DebugLinkMemoryRead

View File

@ -93,9 +93,9 @@ class Protocol:
self.session_counter += 1 self.session_counter += 1
def end_session(self) -> None: def end_session(self) -> None:
if self.session_counter == 1: self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close() self.handle.close()
self.session_counter -= 1
def read(self) -> protobuf.MessageType: def read(self) -> protobuf.MessageType:
raise NotImplementedError raise NotImplementedError

View File

@ -20,6 +20,8 @@ from typing import Iterable, Optional, cast
from . import TransportException from . import TransportException
from .protocol import ProtocolBasedTransport, get_protocol from .protocol import ProtocolBasedTransport, get_protocol
SOCKET_TIMEOUT = 10
class UdpTransport(ProtocolBasedTransport): class UdpTransport(ProtocolBasedTransport):
@ -85,7 +87,7 @@ class UdpTransport(ProtocolBasedTransport):
def open(self) -> None: def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device) self.socket.connect(self.device)
self.socket.settimeout(10) self.socket.settimeout(SOCKET_TIMEOUT)
def close(self) -> None: def close(self) -> None:
if self.socket is not None: if self.socket is not None:

44
tests/buttons.py Normal file
View File

@ -0,0 +1,44 @@
DISPLAY_WIDTH = 240
DISPLAY_HEIGHT = 240
def grid(dim, grid_cells, cell):
step = dim // grid_cells
ofs = step // 2
return cell * step + ofs
LEFT = grid(DISPLAY_WIDTH, 3, 0)
MID = grid(DISPLAY_WIDTH, 3, 1)
RIGHT = grid(DISPLAY_WIDTH, 3, 2)
TOP = grid(DISPLAY_HEIGHT, 4, 0)
BOTTOM = grid(DISPLAY_HEIGHT, 4, 3)
OK = (RIGHT, BOTTOM)
CANCEL = (LEFT, BOTTOM)
INFO = (MID, BOTTOM)
CONFIRM_WORD = (MID, TOP)
MINUS = (LEFT, grid(DISPLAY_HEIGHT, 5, 2))
PLUS = (RIGHT, grid(DISPLAY_HEIGHT, 5, 2))
BUTTON_LETTERS = ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz")
def grid35(x, y):
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 5, y)
def grid34(x, y):
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 4, y)
def type_word(word):
for l in word:
idx = next(i for i, letters in enumerate(BUTTON_LETTERS) if l in letters)
grid_x = idx % 3
grid_y = idx // 3 + 1 # first line is empty
yield grid34(grid_x, grid_y)

View File

View File

@ -0,0 +1,76 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2019 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import pytest
from trezorlib import device, messages
from .. import buttons
from ..common import MNEMONIC_SLIP39_BASIC_20_3of6
def enter_word(debug, word):
word = word[:4]
for coords in buttons.type_word(word):
debug.click(coords)
return debug.click(buttons.CONFIRM_WORD, wait=True)
@pytest.mark.skip_t1
@pytest.mark.setup_client(uninitialized=True)
def test_recovery(device_handler):
features = device_handler.features()
debug = device_handler.debuglink()
assert features.initialized is False
device_handler.run(device.recover, pin_protection=False)
# select number of words
layout = debug.wait_layout()
assert layout.text.startswith("Recovery mode")
layout = debug.click(buttons.OK, wait=True)
assert "Select number of words" in layout.text
layout = debug.click(buttons.OK, wait=True)
assert layout.text == "WordSelector"
# click "20" at 2, 2
coords = buttons.grid34(2, 2)
lines = debug.click(coords, wait=True)
layout = " ".join(lines)
expected_text = "Enter any share (20 words)"
remaining = len(MNEMONIC_SLIP39_BASIC_20_3of6)
for share in MNEMONIC_SLIP39_BASIC_20_3of6:
assert expected_text in layout.text
layout = debug.click(buttons.OK, wait=True)
assert layout.text == "Slip39Keyboard"
for word in share.split(" "):
layout = enter_word(debug, word)
remaining -= 1
expected_text = "RecoveryHomescreen {} more".format(remaining)
assert "You have successfully recovered your wallet" in layout.text
layout = debug.click(buttons.OK, wait=True)
assert layout.text == "Homescreen"
assert isinstance(device_handler.result(), messages.Success)
features = device_handler.features()
assert features.initialized is True
assert features.recovery_mode is False

View File

@ -24,6 +24,8 @@ from trezorlib.device import apply_settings, wipe as wipe_device
from trezorlib.messages.PassphraseSourceType import HOST as PASSPHRASE_ON_HOST from trezorlib.messages.PassphraseSourceType import HOST as PASSPHRASE_ON_HOST
from trezorlib.transport import enumerate_devices, get_transport from trezorlib.transport import enumerate_devices, get_transport
from .device_handler import BackgroundDeviceHandler
def get_device(): def get_device():
path = os.environ.get("TREZOR_PATH") path = os.environ.get("TREZOR_PATH")
@ -156,3 +158,25 @@ def pytest_runtest_setup(item):
skip_altcoins = int(os.environ.get("TREZOR_PYTEST_SKIP_ALTCOINS", 0)) skip_altcoins = int(os.environ.get("TREZOR_PYTEST_SKIP_ALTCOINS", 0))
if item.get_closest_marker("altcoin") and skip_altcoins: if item.get_closest_marker("altcoin") and skip_altcoins:
pytest.skip("Skipping altcoin test") pytest.skip("Skipping altcoin test")
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
# Make test results available in fixtures.
# See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
# The device_handler fixture uses this as 'request.node.rep_call.passed' attribute,
# in order to raise error only if the test passed.
outcome = yield
rep = outcome.get_result()
setattr(item, f"rep_{rep.when}", rep)
@pytest.fixture
def device_handler(client, request):
device_handler = BackgroundDeviceHandler(client)
yield device_handler
# make sure all background tasks are done
finalized_ok = device_handler.check_finalize()
if request.node.rep_call.passed and not finalized_ok:
raise RuntimeError("Test did not check result of background task")

84
tests/device_handler.py Normal file
View File

@ -0,0 +1,84 @@
from concurrent.futures import ThreadPoolExecutor
from trezorlib.transport import udp
udp.SOCKET_TIMEOUT = 0.1
class NullUI:
@staticmethod
def button_request(code):
pass
@staticmethod
def get_pin(code=None):
raise NotImplementedError("Should not be used with T1")
@staticmethod
def get_passphrase():
raise NotImplementedError("Should not be used with T1")
class BackgroundDeviceHandler:
_pool = ThreadPoolExecutor()
def __init__(self, client):
self.client = client
self.client.ui = NullUI
self.task = None
def run(self, function, *args, **kwargs):
if self.task is not None:
raise RuntimeError("Wait for previous task first")
self.task = self._pool.submit(function, self.client, *args, **kwargs)
def kill_task(self):
if self.task is not None:
# Force close the client, which should raise an exception in a client
# waiting on IO. Does not work over Bridge, because bridge doesn't have
# a close() method.
while self.client.session_counter > 0:
self.client.close()
try:
self.task.result()
except Exception:
pass
self.task = None
def restart(self, emulator):
# TODO handle actual restart as well
self.kill_task()
emulator.restart()
self.client = emulator.client
self.client.ui = NullUI
def result(self):
if self.task is None:
raise RuntimeError("No task running")
try:
return self.task.result()
finally:
self.task = None
def features(self):
if self.task is not None:
raise RuntimeError("Cannot query features while task is running")
self.client.init_device()
return self.client.features
def debuglink(self):
return self.client.debug
def check_finalize(self):
if self.task is not None:
self.kill_task()
return False
return True
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
finalized_ok = self.check_finalize()
if exc_type is None and not finalized_ok:
raise RuntimeError("Exit while task is unfinished")

View File

@ -14,6 +14,7 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import gzip
import os import os
import subprocess import subprocess
import tempfile import tempfile
@ -21,15 +22,17 @@ import time
from collections import defaultdict from collections import defaultdict
from trezorlib.debuglink import TrezorClientDebugLink from trezorlib.debuglink import TrezorClientDebugLink
from trezorlib.transport import TransportException, get_transport from trezorlib.transport.udp import UdpTransport
BINDIR = os.path.dirname(os.path.abspath(__file__)) + "/emulators" ROOT = os.path.abspath(os.path.dirname(__file__) + "/..")
ROOT = os.path.dirname(os.path.abspath(__file__)) + "/../" BINDIR = ROOT + "/tests/emulators"
LOCAL_BUILD_PATHS = { LOCAL_BUILD_PATHS = {
"core": ROOT + "core/build/unix/micropython", "core": ROOT + "/core/build/unix/micropython",
"legacy": ROOT + "legacy/firmware/trezor.elf", "legacy": ROOT + "/legacy/firmware/trezor.elf",
} }
SD_CARD_GZ = ROOT + "/tests/trezor.sdcard.gz"
ENV = {"SDL_VIDEODRIVER": "dummy"} ENV = {"SDL_VIDEODRIVER": "dummy"}
@ -84,46 +87,68 @@ class EmulatorWrapper:
if storage: if storage:
open(self._storage_file(), "wb").write(storage) open(self._storage_file(), "wb").write(storage)
with gzip.open(SD_CARD_GZ, "rb") as gz:
with open(self.workdir.name + "/trezor.sdcard", "wb") as sd:
sd.write(gz.read())
self.client = None self.client = None
def __enter__(self): def _get_params_core(self):
env = ENV.copy()
args = [self.executable, "-m", "main"]
# for firmware 2.1.2 and newer
env["TREZOR_PROFILE_DIR"] = self.workdir.name
# for firmware 2.1.1 and older
env["TREZOR_PROFILE"] = self.workdir.name
if self.executable == LOCAL_BUILD_PATHS["core"]:
cwd = ROOT + "/core/src"
else:
cwd = self.workdir.name
return env, args, cwd
def _get_params_legacy(self):
env = ENV.copy()
args = [self.executable] args = [self.executable]
env = ENV cwd = self.workdir.name
return env, args, cwd
def _get_params(self):
if self.gen == "core": if self.gen == "core":
args += ["-m", "main"] return self._get_params_core()
# for firmware 2.1.2 and newer elif self.gen == "legacy":
env["TREZOR_PROFILE_DIR"] = self.workdir.name return self._get_params_legacy()
# for firmware 2.1.1 and older else:
env["TREZOR_PROFILE"] = self.workdir.name raise ValueError("Unknown gen")
def start(self):
env, args, cwd = self._get_params()
self.process = subprocess.Popen( self.process = subprocess.Popen(
args, cwd=self.workdir.name, env=env, stdout=open(os.devnull, "w") args, cwd=cwd, env=env, stdout=open(os.devnull, "w")
) )
# wait until emulator is listening # wait until emulator is listening
transport = UdpTransport("127.0.0.1:21324")
transport.open()
for _ in range(300): for _ in range(300):
try: if transport._ping():
time.sleep(0.1)
transport = get_transport("udp:127.0.0.1:21324")
break break
except TransportException:
pass
if self.process.poll() is not None: if self.process.poll() is not None:
self._cleanup() self._cleanup()
raise RuntimeError("Emulator proces died") raise RuntimeError("Emulator proces died")
time.sleep(0.1)
else: else:
# could not connect after 300 attempts * 0.1s = 30s of waiting # could not connect after 300 attempts * 0.1s = 30s of waiting
self._cleanup() self._cleanup()
raise RuntimeError("Can't connect to emulator") raise RuntimeError("Can't connect to emulator")
transport.close()
self.client = TrezorClientDebugLink(transport) self.client = TrezorClientDebugLink(transport)
self.client.open() self.client.open()
check_version(self.tag, self.client.version) check_version(self.tag, self.client.version)
return self
def __exit__(self, exc_type, exc_value, traceback): def stop(self):
self._cleanup()
return False
def _cleanup(self):
if self.client: if self.client:
self.client.close() self.client.close()
self.process.terminate() self.process.terminate()
@ -131,6 +156,20 @@ class EmulatorWrapper:
self.process.wait(1) self.process.wait(1)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
self.process.kill() self.process.kill()
def restart(self):
self.stop()
self.start()
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_value, traceback):
self._cleanup()
def _cleanup(self):
self.stop()
self.workdir.cleanup() self.workdir.cleanup()
def _storage_file(self): def _storage_file(self):

BIN
tests/trezor.sdcard.gz Normal file

Binary file not shown.

View File

@ -0,0 +1,37 @@
import os
import pytest
from ..emulators import EmulatorWrapper
SELECTED_GENS = [
gen.strip() for gen in os.environ.get("TREZOR_UPGRADE_TEST", "").split(",") if gen
]
if SELECTED_GENS:
# if any gens were selected via the environment variable, force enable all selected
LEGACY_ENABLED = "legacy" in SELECTED_GENS
CORE_ENABLED = "core" in SELECTED_GENS
else:
# if no selection was provided, select those for which we have emulators
try:
EmulatorWrapper("legacy")
LEGACY_ENABLED = True
except Exception:
LEGACY_ENABLED = False
try:
EmulatorWrapper("core")
CORE_ENABLED = True
except Exception:
CORE_ENABLED = False
legacy_only = pytest.mark.skipif(
not LEGACY_ENABLED, reason="This test requires legacy emulator"
)
core_only = pytest.mark.skipif(
not CORE_ENABLED, reason="This test requires core emulator"
)

View File

@ -14,14 +14,13 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os
import pytest import pytest
from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device
from trezorlib.tools import H_ from trezorlib.tools import H_
from ..emulators import ALL_TAGS, EmulatorWrapper from ..emulators import ALL_TAGS, EmulatorWrapper
from . import SELECTED_GENS
MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0) MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0)
MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0) MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0)
@ -41,11 +40,8 @@ def for_all(*args, minimum_version=(1, 0, 0)):
if not args: if not args:
args = ("core", "legacy") args = ("core", "legacy")
specified_gens = os.environ.get("TREZOR_UPGRADE_TEST") # If any gens were selected, use them. If none, select all.
if specified_gens is not None: enabled_gens = SELECTED_GENS or args
enabled_gens = specified_gens.split(",")
else:
enabled_gens = args
all_params = [] all_params = []
for gen in args: for gen in args:

View File

@ -0,0 +1,72 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2019 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import pytest
from trezorlib import device
from .. import buttons
from ..device_handler import BackgroundDeviceHandler
from ..emulators import EmulatorWrapper
from . import core_only
def enter_word(debug, word):
word = word[:4]
for coords in buttons.type_word(word):
debug.click(coords)
return debug.click(buttons.CONFIRM_WORD, wait=True)
@pytest.fixture
def emulator():
emu = EmulatorWrapper("core")
with emu:
yield emu
@core_only
def test_persistence(emulator):
device_handler = BackgroundDeviceHandler(emulator.client)
debug = device_handler.debuglink()
features = device_handler.features()
assert features.recovery_mode is False
device_handler.run(device.recover, pin_protection=False)
layout = debug.wait_layout()
assert layout.text.startswith("Recovery mode")
layout = debug.click(buttons.OK, wait=True)
assert "Select number of words" in layout.text
device_handler.restart(emulator)
debug = device_handler.debuglink()
features = device_handler.features()
assert features.recovery_mode is True
# no waiting for layout because layout doesn't change
layout = debug.read_layout()
assert "Select number of words" in layout.text
layout = debug.click(buttons.CANCEL, wait=True)
assert layout.text.startswith("Abort recovery")
layout = debug.click(buttons.OK, wait=True)
assert layout.text == "Homescreen"
features = device_handler.features()
assert features.recovery_mode is False