DebugLink click tests (#632)

pull/645/head
matejcik 5 years ago committed by GitHub
commit 09c3fd1981
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,7 +10,7 @@ import "messages-common.proto";
/**
* Request: "Press" the button on the device
* @start
* @next Success
* @next DebugLinkLayout
*/
message DebugLinkDecision {
optional bool yes_no = 1; // true for "Confirm", false for "Cancel"
@ -25,6 +25,18 @@ message DebugLinkDecision {
LEFT = 2;
RIGHT = 3;
}
optional uint32 x = 4; // touch X coordinate
optional uint32 y = 5; // touch Y coordinate
optional bool wait = 6; // wait for layout change
}
/**
* Response: Device text layout
* @end
*/
message DebugLinkLayout {
repeated string lines = 1;
}
/**
@ -35,6 +47,7 @@ message DebugLinkDecision {
message DebugLinkGetState {
optional bool wait_word_list = 1; // Trezor T only - wait until mnemonic words are shown
optional bool wait_word_pos = 2; // Trezor T only - wait until reset word position is requested
optional bool wait_layout = 3; // wait until current layout changes
}
/**
@ -54,6 +67,7 @@ message DebugLinkState {
optional uint32 recovery_word_pos = 10; // index of mnemonic word the device is expecting during RecoveryDevice workflow
optional uint32 reset_word_pos = 11; // index of mnemonic word the device is expecting during ResetDevice workflow
optional uint32 mnemonic_type = 12; // current mnemonic type (BIP-39/SLIP-39)
repeated string layout_lines = 13; // current layout text
}
/**

@ -103,6 +103,7 @@ enum MessageType {
MessageType_DebugLinkMemory = 111 [(wire_debug_out) = true];
MessageType_DebugLinkMemoryWrite = 112 [(wire_debug_in) = true];
MessageType_DebugLinkFlashErase = 113 [(wire_debug_in) = true];
MessageType_DebugLinkLayout = 9001 [(wire_debug_out) = true];
// Ethereum
MessageType_EthereumGetPublicKey = 450 [(wire_in) = true];

@ -4,12 +4,13 @@ if not __debug__:
halt("debug mode inactive")
if __debug__:
from trezor import config, log, loop, utils
from trezor import config, io, log, loop, ui, utils
from trezor.messages import MessageType, DebugSwipeDirection
from trezor.messages.DebugLinkLayout import DebugLinkLayout
from trezor.wire import register
if False:
from typing import Optional
from typing import List, Optional
from trezor import wire
from trezor.messages.DebugLinkDecision import DebugLinkDecision
from trezor.messages.DebugLinkGetState import DebugLinkGetState
@ -28,6 +29,15 @@ if __debug__:
debuglink_decision_chan = loop.chan()
layout_change_chan = loop.chan()
current_content = None # type: Optional[List[str]]
def notify_layout_change(layout: ui.Layout) -> None:
global current_content
current_content = layout.read_content()
if layout_change_chan.takers:
layout_change_chan.publish(current_content)
async def debuglink_decision_dispatcher() -> None:
from trezor.ui import confirm, swipe
@ -51,14 +61,26 @@ if __debug__:
loop.schedule(debuglink_decision_dispatcher())
async def return_layout_change(ctx: wire.Context) -> None:
content = await layout_change_chan.take()
await ctx.write(DebugLinkLayout(lines=content))
async def dispatch_DebugLinkDecision(
ctx: wire.Context, msg: DebugLinkDecision
) -> None:
if debuglink_decision_chan.putters:
log.warning(__name__, "DebugLinkDecision queue is not empty")
debuglink_decision_chan.publish(msg)
if msg.x is not None:
evt_down = io.TOUCH_START, msg.x, msg.y
evt_up = io.TOUCH_END, msg.x, msg.y
loop.synthetic_events.append((io.TOUCH, evt_down))
loop.synthetic_events.append((io.TOUCH, evt_up))
else:
debuglink_decision_chan.publish(msg)
if msg.wait:
loop.schedule(return_layout_change(ctx))
async def dispatch_DebugLinkGetState(
ctx: wire.Context, msg: DebugLinkGetState
@ -73,6 +95,11 @@ if __debug__:
m.passphrase_protection = has_passphrase()
m.reset_entropy = reset_internal_entropy
if msg.wait_layout or current_content is None:
m.layout_lines = await layout_change_chan.take()
else:
m.layout_lines = current_content
if msg.wait_word_pos:
m.reset_word_pos = await reset_word_index.take()
if msg.wait_word_list:

@ -310,6 +310,11 @@ class RecoveryHomescreen(ui.Component):
self.repaint = False
if __debug__:
def read_content(self) -> List[str]:
return [self.__class__.__name__, self.text, self.subtext or ""]
async def homescreen_dialog(
ctx: wire.GenericContext,

@ -643,6 +643,11 @@ class MnemonicWordSelect(ui.Layout):
return fn
if __debug__:
def read_content(self):
return self.text.read_content() + [b.text for b in self.buttons]
async def show_reset_device_warning(ctx, backup_type: BackupType = BackupType.Bip39):
text = Text("Create new wallet", ui.ICON_RESET, new_lines=False)

@ -21,6 +21,8 @@ def _boot_recovery() -> None:
# boot applications
apps.homescreen.boot(features_only=True)
if __debug__:
apps.debug.boot()
from apps.management.recovery_device.homescreen import recovery_homescreen

@ -50,6 +50,9 @@ if __debug__:
log_delay_rb_len = const(10)
log_delay_rb = array.array("i", [0] * log_delay_rb_len)
# synthetic event queue
synthetic_events = [] # type: List[Tuple[int, Any]]
def schedule(
task: Task, value: Any = None, deadline: int = None, finalizer: Finalizer = None
@ -125,6 +128,15 @@ def run() -> None:
log_delay_rb[log_delay_pos] = delay
log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len
# process synthetic events
if synthetic_events:
iface, event = synthetic_events[0]
msg_tasks = _paused.pop(iface, ())
if msg_tasks:
synthetic_events.pop(0)
for task in msg_tasks:
_step(task, event)
if io.poll(_paused, msg_entry, delay):
# message received, run tasks paused on the interface
msg_tasks = _paused.pop(msg_entry[0], ())

@ -19,10 +19,16 @@ class DebugLinkDecision(p.MessageType):
yes_no: bool = None,
swipe: EnumTypeDebugSwipeDirection = None,
input: str = None,
x: int = None,
y: int = None,
wait: bool = None,
) -> None:
self.yes_no = yes_no
self.swipe = swipe
self.input = input
self.x = x
self.y = y
self.wait = wait
@classmethod
def get_fields(cls) -> Dict:
@ -30,4 +36,7 @@ class DebugLinkDecision(p.MessageType):
1: ('yes_no', p.BoolType, 0),
2: ('swipe', p.EnumType("DebugSwipeDirection", (0, 1, 2, 3)), 0),
3: ('input', p.UnicodeType, 0),
4: ('x', p.UVarintType, 0),
5: ('y', p.UVarintType, 0),
6: ('wait', p.BoolType, 0),
}

@ -17,13 +17,16 @@ class DebugLinkGetState(p.MessageType):
self,
wait_word_list: bool = None,
wait_word_pos: bool = None,
wait_layout: bool = None,
) -> None:
self.wait_word_list = wait_word_list
self.wait_word_pos = wait_word_pos
self.wait_layout = wait_layout
@classmethod
def get_fields(cls) -> Dict:
return {
1: ('wait_word_list', p.BoolType, 0),
2: ('wait_word_pos', p.BoolType, 0),
3: ('wait_layout', p.BoolType, 0),
}

@ -0,0 +1,26 @@
# Automatically generated by pb2py
# fmt: off
import protobuf as p
if __debug__:
try:
from typing import Dict, List # noqa: F401
from typing_extensions import Literal # noqa: F401
except ImportError:
pass
class DebugLinkLayout(p.MessageType):
MESSAGE_WIRE_TYPE = 9001
def __init__(
self,
lines: List[str] = None,
) -> None:
self.lines = lines if lines is not None else []
@classmethod
def get_fields(cls) -> Dict:
return {
1: ('lines', p.UnicodeType, p.FLAG_REPEATED),
}

@ -29,6 +29,7 @@ class DebugLinkState(p.MessageType):
recovery_word_pos: int = None,
reset_word_pos: int = None,
mnemonic_type: int = None,
layout_lines: List[str] = None,
) -> None:
self.layout = layout
self.pin = pin
@ -42,6 +43,7 @@ class DebugLinkState(p.MessageType):
self.recovery_word_pos = recovery_word_pos
self.reset_word_pos = reset_word_pos
self.mnemonic_type = mnemonic_type
self.layout_lines = layout_lines if layout_lines is not None else []
@classmethod
def get_fields(cls) -> Dict:
@ -58,4 +60,5 @@ class DebugLinkState(p.MessageType):
10: ('recovery_word_pos', p.UVarintType, 0),
11: ('reset_word_pos', p.UVarintType, 0),
12: ('mnemonic_type', p.UVarintType, 0),
13: ('layout_lines', p.UnicodeType, p.FLAG_REPEATED),
}

@ -70,6 +70,7 @@ DebugLinkMemoryRead = 110 # type: Literal[110]
DebugLinkMemory = 111 # type: Literal[111]
DebugLinkMemoryWrite = 112 # type: Literal[112]
DebugLinkFlashErase = 113 # type: Literal[113]
DebugLinkLayout = 9001 # type: Literal[9001]
if not utils.BITCOIN_ONLY:
EthereumGetPublicKey = 450 # type: Literal[450]
EthereumPublicKey = 451 # type: Literal[451]

@ -5,8 +5,11 @@ from trezorui import Display
from trezor import io, loop, res, utils
if __debug__:
from apps.debug import notify_layout_change
if False:
from typing import Any, Awaitable, Generator, Tuple, TypeVar
from typing import Any, Awaitable, Generator, List, Tuple, TypeVar
Pos = Tuple[int, int]
Area = Tuple[int, int, int, int]
@ -226,6 +229,11 @@ class Component:
def on_touch_end(self, x: int, y: int) -> None:
pass
if __debug__:
def read_content(self) -> List[str]:
return [self.__class__.__name__]
class Result(Exception):
"""
@ -279,6 +287,8 @@ class Layout(Component):
# layout channel. This allows other layouts to cancel us, and the
# layout tasks to trigger restart by exiting (new tasks are created
# and we continue, because we are in a loop).
if __debug__:
notify_layout_change(self)
while True:
await loop.race(layout_chan.take(), *self.create_tasks())
except Result as result:

@ -4,7 +4,7 @@ from trezor import ui
from trezor.ui import display, in_area
if False:
from typing import Type, Union, Optional
from typing import List, Optional, Type, Union
class ButtonDefault:
@ -239,3 +239,8 @@ class Button(ui.Component):
def on_click(self) -> None:
pass
if __debug__:
def read_content(self) -> List[str]:
return ["<Button: {}>".format(self.text)]

@ -8,7 +8,7 @@ if __debug__:
from apps.debug import swipe_signal
if False:
from typing import Any, Optional, Tuple
from typing import Any, Optional, List, Tuple
from trezor.ui.button import ButtonContent, ButtonStyleType
from trezor.ui.loader import LoaderStyleType
@ -74,6 +74,11 @@ class Confirm(ui.Layout):
def on_cancel(self) -> None:
raise ui.Result(CANCELLED)
if __debug__:
def read_content(self) -> List[str]:
return self.content.read_content()
class Pageable:
def __init__(self) -> None:
@ -201,6 +206,11 @@ class InfoConfirm(ui.Layout):
def on_info(self) -> None:
raise ui.Result(INFO)
if __debug__:
def read_content(self) -> List[str]:
return self.content.read_content()
class HoldToConfirm(ui.Layout):
DEFAULT_CONFIRM = "Hold To Confirm"
@ -250,3 +260,8 @@ class HoldToConfirm(ui.Layout):
def on_confirm(self) -> None:
raise ui.Result(CONFIRMED)
if __debug__:
def read_content(self) -> List[str]:
return self.content.read_content()

@ -1,5 +1,8 @@
from trezor import ui
if False:
from typing import List
class Container(ui.Component):
def __init__(self, *children: ui.Component):
@ -8,3 +11,8 @@ class Container(ui.Component):
def dispatch(self, event: int, x: int, y: int) -> None:
for child in self.children:
child.dispatch(event, x, y)
if __debug__:
def read_content(self) -> List[str]:
return sum((c.read_content() for c in self.children), [])

@ -6,10 +6,10 @@ from trezor.ui.confirm import CANCELLED, CONFIRMED
from trezor.ui.swipe import SWIPE_DOWN, SWIPE_UP, SWIPE_VERTICAL, Swipe
if __debug__:
from apps.debug import swipe_signal
from apps.debug import swipe_signal, notify_layout_change
if False:
from typing import Tuple, List
from typing import List, Tuple
def render_scrollbar(pages: int, page: int) -> None:
@ -89,6 +89,9 @@ class Paginated(ui.Layout):
self.pages[self.page].dispatch(ui.REPAINT, 0, 0)
self.repaint = True
if __debug__:
notify_layout_change(self)
self.on_change()
def create_tasks(self) -> Tuple[loop.Task, ...]:
@ -98,6 +101,11 @@ class Paginated(ui.Layout):
if self.one_by_one:
raise ui.Result(self.page)
if __debug__:
def read_content(self) -> List[str]:
return self.pages[self.page].read_content()
class PageWithButtons(ui.Component):
def __init__(
@ -154,6 +162,11 @@ class PageWithButtons(ui.Component):
else:
self.paginated.on_down()
if __debug__:
def read_content(self) -> List[str]:
return self.content.read_content()
class PaginatedWithButtons(ui.Layout):
def __init__(
@ -191,3 +204,8 @@ class PaginatedWithButtons(ui.Layout):
def on_change(self) -> None:
if self.one_by_one:
raise ui.Result(self.page)
if __debug__:
def read_content(self) -> List[str]:
return self.pages[self.page].read_content()

@ -171,6 +171,12 @@ class Text(ui.Component):
render_text(self.content, self.new_lines, self.max_lines)
self.repaint = False
if __debug__:
def read_content(self) -> List[str]:
lines = [w for w in self.content if isinstance(w, str)]
return [self.header_text] + lines[: self.max_lines]
LABEL_LEFT = const(0)
LABEL_CENTER = const(1)
@ -209,6 +215,11 @@ class Label(ui.Component):
)
self.repaint = False
if __debug__:
def read_content(self) -> List[str]:
return [self.content]
def text_center_trim_left(
x: int, y: int, text: str, font: int = ui.NORMAL, width: int = ui.WIDTH - 16

@ -40,7 +40,7 @@ from trezor import log, loop, messages, ui, utils, workflow
from trezor.messages import FailureType
from trezor.messages.Failure import Failure
from trezor.wire import codec_v1
from trezor.wire.errors import Error
from trezor.wire.errors import ActionCancelled, Error
# Import all errors into namespace, so that `wire.Error` is available from
# other packages.
@ -364,7 +364,10 @@ async def handle_session(iface: WireInterface, session_id: int) -> None:
# - the first workflow message was not a valid protobuf
# - workflow raised some kind of an exception while running
if __debug__:
log.exception(__name__, exc)
if isinstance(exc, ActionCancelled):
log.debug(__name__, "cancelled: {}".format(exc.message))
else:
log.exception(__name__, exc)
res_msg = failure(exc)
finally:

@ -13,3 +13,9 @@ DebugLinkLog.text max_size:256
DebugLinkMemory.memory max_size:1024
DebugLinkMemoryWrite.memory max_size:1024
# unused fields
DebugLinkState.layout_lines max_count:0
DebugLinkState.layout_lines max_size:1
DebugLinkLayout.lines max_size:1
DebugLinkLayout.lines max_count:0

@ -132,9 +132,9 @@ class TrezorClient:
self.session_counter += 1
def close(self):
if self.session_counter == 1:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.transport.end_session()
self.session_counter -= 1
def cancel(self):
self._raw_write(messages.Cancel())

@ -14,6 +14,7 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from collections import namedtuple
from copy import deepcopy
from mnemonic import Mnemonic
@ -25,6 +26,13 @@ from .tools import expect
EXPECTED_RESPONSES_CONTEXT_LINES = 3
LayoutLines = namedtuple("LayoutLines", "lines text")
def layout_lines(lines):
return LayoutLines(lines, " ".join(lines))
class DebugLink:
def __init__(self, transport, auto_interact=True):
self.transport = transport
@ -46,6 +54,13 @@ class DebugLink:
def state(self):
return self._call(proto.DebugLinkGetState())
def read_layout(self):
return layout_lines(self.state().layout_lines)
def wait_layout(self):
obj = self._call(proto.DebugLinkGetState(wait_layout=True))
return layout_lines(obj.layout_lines)
def read_pin(self):
state = self.state()
return state.pin, state.matrix
@ -83,16 +98,24 @@ class DebugLink:
obj = self._call(proto.DebugLinkGetState())
return obj.passphrase_protection
def input(self, word=None, button=None, swipe=None):
def input(self, word=None, button=None, swipe=None, x=None, y=None, wait=False):
if not self.allow_interactions:
return
args = sum(a is not None for a in (word, button, swipe))
args = sum(a is not None for a in (word, button, swipe, x))
if args != 1:
raise ValueError("Invalid input - must use one of word, button, swipe")
decision = proto.DebugLinkDecision(yes_no=button, swipe=swipe, input=word)
self._call(decision, nowait=True)
decision = proto.DebugLinkDecision(
yes_no=button, swipe=swipe, input=word, x=x, y=y, wait=wait
)
ret = self._call(decision, nowait=not wait)
if ret is not None:
return layout_lines(ret.lines)
def click(self, click, wait=False):
x, y = click
return self.input(x=x, y=y, wait=wait)
def press_yes(self):
self.input(button=True)

@ -19,10 +19,16 @@ class DebugLinkDecision(p.MessageType):
yes_no: bool = None,
swipe: EnumTypeDebugSwipeDirection = None,
input: str = None,
x: int = None,
y: int = None,
wait: bool = None,
) -> None:
self.yes_no = yes_no
self.swipe = swipe
self.input = input
self.x = x
self.y = y
self.wait = wait
@classmethod
def get_fields(cls) -> Dict:
@ -30,4 +36,7 @@ class DebugLinkDecision(p.MessageType):
1: ('yes_no', p.BoolType, 0),
2: ('swipe', p.EnumType("DebugSwipeDirection", (0, 1, 2, 3)), 0),
3: ('input', p.UnicodeType, 0),
4: ('x', p.UVarintType, 0),
5: ('y', p.UVarintType, 0),
6: ('wait', p.BoolType, 0),
}

@ -17,13 +17,16 @@ class DebugLinkGetState(p.MessageType):
self,
wait_word_list: bool = None,
wait_word_pos: bool = None,
wait_layout: bool = None,
) -> None:
self.wait_word_list = wait_word_list
self.wait_word_pos = wait_word_pos
self.wait_layout = wait_layout
@classmethod
def get_fields(cls) -> Dict:
return {
1: ('wait_word_list', p.BoolType, 0),
2: ('wait_word_pos', p.BoolType, 0),
3: ('wait_layout', p.BoolType, 0),
}

@ -0,0 +1,26 @@
# Automatically generated by pb2py
# fmt: off
from .. import protobuf as p
if __debug__:
try:
from typing import Dict, List # noqa: F401
from typing_extensions import Literal # noqa: F401
except ImportError:
pass
class DebugLinkLayout(p.MessageType):
MESSAGE_WIRE_TYPE = 9001
def __init__(
self,
lines: List[str] = None,
) -> None:
self.lines = lines if lines is not None else []
@classmethod
def get_fields(cls) -> Dict:
return {
1: ('lines', p.UnicodeType, p.FLAG_REPEATED),
}

@ -29,6 +29,7 @@ class DebugLinkState(p.MessageType):
recovery_word_pos: int = None,
reset_word_pos: int = None,
mnemonic_type: int = None,
layout_lines: List[str] = None,
) -> None:
self.layout = layout
self.pin = pin
@ -42,6 +43,7 @@ class DebugLinkState(p.MessageType):
self.recovery_word_pos = recovery_word_pos
self.reset_word_pos = reset_word_pos
self.mnemonic_type = mnemonic_type
self.layout_lines = layout_lines if layout_lines is not None else []
@classmethod
def get_fields(cls) -> Dict:
@ -58,4 +60,5 @@ class DebugLinkState(p.MessageType):
10: ('recovery_word_pos', p.UVarintType, 0),
11: ('reset_word_pos', p.UVarintType, 0),
12: ('mnemonic_type', p.UVarintType, 0),
13: ('layout_lines', p.UnicodeType, p.FLAG_REPEATED),
}

@ -68,6 +68,7 @@ DebugLinkMemoryRead = 110 # type: Literal[110]
DebugLinkMemory = 111 # type: Literal[111]
DebugLinkMemoryWrite = 112 # type: Literal[112]
DebugLinkFlashErase = 113 # type: Literal[113]
DebugLinkLayout = 9001 # type: Literal[9001]
EthereumGetPublicKey = 450 # type: Literal[450]
EthereumPublicKey = 451 # type: Literal[451]
EthereumGetAddress = 56 # type: Literal[56]

@ -41,6 +41,7 @@ from .CosiSignature import CosiSignature
from .DebugLinkDecision import DebugLinkDecision
from .DebugLinkFlashErase import DebugLinkFlashErase
from .DebugLinkGetState import DebugLinkGetState
from .DebugLinkLayout import DebugLinkLayout
from .DebugLinkLog import DebugLinkLog
from .DebugLinkMemory import DebugLinkMemory
from .DebugLinkMemoryRead import DebugLinkMemoryRead

@ -93,9 +93,9 @@ class Protocol:
self.session_counter += 1
def end_session(self) -> None:
if self.session_counter == 1:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close()
self.session_counter -= 1
def read(self) -> protobuf.MessageType:
raise NotImplementedError

@ -20,6 +20,8 @@ from typing import Iterable, Optional, cast
from . import TransportException
from .protocol import ProtocolBasedTransport, get_protocol
SOCKET_TIMEOUT = 10
class UdpTransport(ProtocolBasedTransport):
@ -85,7 +87,7 @@ class UdpTransport(ProtocolBasedTransport):
def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.connect(self.device)
self.socket.settimeout(10)
self.socket.settimeout(SOCKET_TIMEOUT)
def close(self) -> None:
if self.socket is not None:

@ -0,0 +1,44 @@
DISPLAY_WIDTH = 240
DISPLAY_HEIGHT = 240
def grid(dim, grid_cells, cell):
step = dim // grid_cells
ofs = step // 2
return cell * step + ofs
LEFT = grid(DISPLAY_WIDTH, 3, 0)
MID = grid(DISPLAY_WIDTH, 3, 1)
RIGHT = grid(DISPLAY_WIDTH, 3, 2)
TOP = grid(DISPLAY_HEIGHT, 4, 0)
BOTTOM = grid(DISPLAY_HEIGHT, 4, 3)
OK = (RIGHT, BOTTOM)
CANCEL = (LEFT, BOTTOM)
INFO = (MID, BOTTOM)
CONFIRM_WORD = (MID, TOP)
MINUS = (LEFT, grid(DISPLAY_HEIGHT, 5, 2))
PLUS = (RIGHT, grid(DISPLAY_HEIGHT, 5, 2))
BUTTON_LETTERS = ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz")
def grid35(x, y):
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 5, y)
def grid34(x, y):
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 4, y)
def type_word(word):
for l in word:
idx = next(i for i, letters in enumerate(BUTTON_LETTERS) if l in letters)
grid_x = idx % 3
grid_y = idx // 3 + 1 # first line is empty
yield grid34(grid_x, grid_y)

@ -0,0 +1,76 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2019 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import pytest
from trezorlib import device, messages
from .. import buttons
from ..common import MNEMONIC_SLIP39_BASIC_20_3of6
def enter_word(debug, word):
word = word[:4]
for coords in buttons.type_word(word):
debug.click(coords)
return debug.click(buttons.CONFIRM_WORD, wait=True)
@pytest.mark.skip_t1
@pytest.mark.setup_client(uninitialized=True)
def test_recovery(device_handler):
features = device_handler.features()
debug = device_handler.debuglink()
assert features.initialized is False
device_handler.run(device.recover, pin_protection=False)
# select number of words
layout = debug.wait_layout()
assert layout.text.startswith("Recovery mode")
layout = debug.click(buttons.OK, wait=True)
assert "Select number of words" in layout.text
layout = debug.click(buttons.OK, wait=True)
assert layout.text == "WordSelector"
# click "20" at 2, 2
coords = buttons.grid34(2, 2)
lines = debug.click(coords, wait=True)
layout = " ".join(lines)
expected_text = "Enter any share (20 words)"
remaining = len(MNEMONIC_SLIP39_BASIC_20_3of6)
for share in MNEMONIC_SLIP39_BASIC_20_3of6:
assert expected_text in layout.text
layout = debug.click(buttons.OK, wait=True)
assert layout.text == "Slip39Keyboard"
for word in share.split(" "):
layout = enter_word(debug, word)
remaining -= 1
expected_text = "RecoveryHomescreen {} more".format(remaining)
assert "You have successfully recovered your wallet" in layout.text
layout = debug.click(buttons.OK, wait=True)
assert layout.text == "Homescreen"
assert isinstance(device_handler.result(), messages.Success)
features = device_handler.features()
assert features.initialized is True
assert features.recovery_mode is False

@ -24,6 +24,8 @@ from trezorlib.device import apply_settings, wipe as wipe_device
from trezorlib.messages.PassphraseSourceType import HOST as PASSPHRASE_ON_HOST
from trezorlib.transport import enumerate_devices, get_transport
from .device_handler import BackgroundDeviceHandler
def get_device():
path = os.environ.get("TREZOR_PATH")
@ -156,3 +158,25 @@ def pytest_runtest_setup(item):
skip_altcoins = int(os.environ.get("TREZOR_PYTEST_SKIP_ALTCOINS", 0))
if item.get_closest_marker("altcoin") and skip_altcoins:
pytest.skip("Skipping altcoin test")
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
# Make test results available in fixtures.
# See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
# The device_handler fixture uses this as 'request.node.rep_call.passed' attribute,
# in order to raise error only if the test passed.
outcome = yield
rep = outcome.get_result()
setattr(item, f"rep_{rep.when}", rep)
@pytest.fixture
def device_handler(client, request):
device_handler = BackgroundDeviceHandler(client)
yield device_handler
# make sure all background tasks are done
finalized_ok = device_handler.check_finalize()
if request.node.rep_call.passed and not finalized_ok:
raise RuntimeError("Test did not check result of background task")

@ -0,0 +1,84 @@
from concurrent.futures import ThreadPoolExecutor
from trezorlib.transport import udp
udp.SOCKET_TIMEOUT = 0.1
class NullUI:
@staticmethod
def button_request(code):
pass
@staticmethod
def get_pin(code=None):
raise NotImplementedError("Should not be used with T1")
@staticmethod
def get_passphrase():
raise NotImplementedError("Should not be used with T1")
class BackgroundDeviceHandler:
_pool = ThreadPoolExecutor()
def __init__(self, client):
self.client = client
self.client.ui = NullUI
self.task = None
def run(self, function, *args, **kwargs):
if self.task is not None:
raise RuntimeError("Wait for previous task first")
self.task = self._pool.submit(function, self.client, *args, **kwargs)
def kill_task(self):
if self.task is not None:
# Force close the client, which should raise an exception in a client
# waiting on IO. Does not work over Bridge, because bridge doesn't have
# a close() method.
while self.client.session_counter > 0:
self.client.close()
try:
self.task.result()
except Exception:
pass
self.task = None
def restart(self, emulator):
# TODO handle actual restart as well
self.kill_task()
emulator.restart()
self.client = emulator.client
self.client.ui = NullUI
def result(self):
if self.task is None:
raise RuntimeError("No task running")
try:
return self.task.result()
finally:
self.task = None
def features(self):
if self.task is not None:
raise RuntimeError("Cannot query features while task is running")
self.client.init_device()
return self.client.features
def debuglink(self):
return self.client.debug
def check_finalize(self):
if self.task is not None:
self.kill_task()
return False
return True
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
finalized_ok = self.check_finalize()
if exc_type is None and not finalized_ok:
raise RuntimeError("Exit while task is unfinished")

@ -14,6 +14,7 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import gzip
import os
import subprocess
import tempfile
@ -21,15 +22,17 @@ import time
from collections import defaultdict
from trezorlib.debuglink import TrezorClientDebugLink
from trezorlib.transport import TransportException, get_transport
from trezorlib.transport.udp import UdpTransport
BINDIR = os.path.dirname(os.path.abspath(__file__)) + "/emulators"
ROOT = os.path.dirname(os.path.abspath(__file__)) + "/../"
ROOT = os.path.abspath(os.path.dirname(__file__) + "/..")
BINDIR = ROOT + "/tests/emulators"
LOCAL_BUILD_PATHS = {
"core": ROOT + "core/build/unix/micropython",
"legacy": ROOT + "legacy/firmware/trezor.elf",
"core": ROOT + "/core/build/unix/micropython",
"legacy": ROOT + "/legacy/firmware/trezor.elf",
}
SD_CARD_GZ = ROOT + "/tests/trezor.sdcard.gz"
ENV = {"SDL_VIDEODRIVER": "dummy"}
@ -84,46 +87,68 @@ class EmulatorWrapper:
if storage:
open(self._storage_file(), "wb").write(storage)
with gzip.open(SD_CARD_GZ, "rb") as gz:
with open(self.workdir.name + "/trezor.sdcard", "wb") as sd:
sd.write(gz.read())
self.client = None
def __enter__(self):
def _get_params_core(self):
env = ENV.copy()
args = [self.executable, "-m", "main"]
# for firmware 2.1.2 and newer
env["TREZOR_PROFILE_DIR"] = self.workdir.name
# for firmware 2.1.1 and older
env["TREZOR_PROFILE"] = self.workdir.name
if self.executable == LOCAL_BUILD_PATHS["core"]:
cwd = ROOT + "/core/src"
else:
cwd = self.workdir.name
return env, args, cwd
def _get_params_legacy(self):
env = ENV.copy()
args = [self.executable]
env = ENV
cwd = self.workdir.name
return env, args, cwd
def _get_params(self):
if self.gen == "core":
args += ["-m", "main"]
# for firmware 2.1.2 and newer
env["TREZOR_PROFILE_DIR"] = self.workdir.name
# for firmware 2.1.1 and older
env["TREZOR_PROFILE"] = self.workdir.name
return self._get_params_core()
elif self.gen == "legacy":
return self._get_params_legacy()
else:
raise ValueError("Unknown gen")
def start(self):
env, args, cwd = self._get_params()
self.process = subprocess.Popen(
args, cwd=self.workdir.name, env=env, stdout=open(os.devnull, "w")
args, cwd=cwd, env=env, stdout=open(os.devnull, "w")
)
# wait until emulator is listening
transport = UdpTransport("127.0.0.1:21324")
transport.open()
for _ in range(300):
try:
time.sleep(0.1)
transport = get_transport("udp:127.0.0.1:21324")
if transport._ping():
break
except TransportException:
pass
if self.process.poll() is not None:
self._cleanup()
raise RuntimeError("Emulator proces died")
time.sleep(0.1)
else:
# could not connect after 300 attempts * 0.1s = 30s of waiting
self._cleanup()
raise RuntimeError("Can't connect to emulator")
transport.close()
self.client = TrezorClientDebugLink(transport)
self.client.open()
check_version(self.tag, self.client.version)
return self
def __exit__(self, exc_type, exc_value, traceback):
self._cleanup()
return False
def _cleanup(self):
def stop(self):
if self.client:
self.client.close()
self.process.terminate()
@ -131,6 +156,20 @@ class EmulatorWrapper:
self.process.wait(1)
except subprocess.TimeoutExpired:
self.process.kill()
def restart(self):
self.stop()
self.start()
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_value, traceback):
self._cleanup()
def _cleanup(self):
self.stop()
self.workdir.cleanup()
def _storage_file(self):

Binary file not shown.

@ -0,0 +1,37 @@
import os
import pytest
from ..emulators import EmulatorWrapper
SELECTED_GENS = [
gen.strip() for gen in os.environ.get("TREZOR_UPGRADE_TEST", "").split(",") if gen
]
if SELECTED_GENS:
# if any gens were selected via the environment variable, force enable all selected
LEGACY_ENABLED = "legacy" in SELECTED_GENS
CORE_ENABLED = "core" in SELECTED_GENS
else:
# if no selection was provided, select those for which we have emulators
try:
EmulatorWrapper("legacy")
LEGACY_ENABLED = True
except Exception:
LEGACY_ENABLED = False
try:
EmulatorWrapper("core")
CORE_ENABLED = True
except Exception:
CORE_ENABLED = False
legacy_only = pytest.mark.skipif(
not LEGACY_ENABLED, reason="This test requires legacy emulator"
)
core_only = pytest.mark.skipif(
not CORE_ENABLED, reason="This test requires core emulator"
)

@ -14,14 +14,13 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os
import pytest
from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device
from trezorlib.tools import H_
from ..emulators import ALL_TAGS, EmulatorWrapper
from . import SELECTED_GENS
MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0)
MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0)
@ -41,11 +40,8 @@ def for_all(*args, minimum_version=(1, 0, 0)):
if not args:
args = ("core", "legacy")
specified_gens = os.environ.get("TREZOR_UPGRADE_TEST")
if specified_gens is not None:
enabled_gens = specified_gens.split(",")
else:
enabled_gens = args
# If any gens were selected, use them. If none, select all.
enabled_gens = SELECTED_GENS or args
all_params = []
for gen in args:

@ -0,0 +1,72 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2019 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import pytest
from trezorlib import device
from .. import buttons
from ..device_handler import BackgroundDeviceHandler
from ..emulators import EmulatorWrapper
from . import core_only
def enter_word(debug, word):
word = word[:4]
for coords in buttons.type_word(word):
debug.click(coords)
return debug.click(buttons.CONFIRM_WORD, wait=True)
@pytest.fixture
def emulator():
emu = EmulatorWrapper("core")
with emu:
yield emu
@core_only
def test_persistence(emulator):
device_handler = BackgroundDeviceHandler(emulator.client)
debug = device_handler.debuglink()
features = device_handler.features()
assert features.recovery_mode is False
device_handler.run(device.recover, pin_protection=False)
layout = debug.wait_layout()
assert layout.text.startswith("Recovery mode")
layout = debug.click(buttons.OK, wait=True)
assert "Select number of words" in layout.text
device_handler.restart(emulator)
debug = device_handler.debuglink()
features = device_handler.features()
assert features.recovery_mode is True
# no waiting for layout because layout doesn't change
layout = debug.read_layout()
assert "Select number of words" in layout.text
layout = debug.click(buttons.CANCEL, wait=True)
assert layout.text.startswith("Abort recovery")
layout = debug.click(buttons.OK, wait=True)
assert layout.text == "Homescreen"
features = device_handler.features()
assert features.recovery_mode is False
Loading…
Cancel
Save