1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-12 08:20:56 +00:00

chore(tests): add type hints to helper test functions

This commit is contained in:
grdddj 2022-01-28 19:26:03 +01:00 committed by matejcik
parent 5d76144ef5
commit c77e18d77c
14 changed files with 174 additions and 101 deletions

View File

@ -18,6 +18,7 @@ import hashlib
import hmac import hmac
import struct import struct
from copy import copy from copy import copy
from typing import Any, List, Tuple
import ecdsa import ecdsa
from ecdsa.curves import SECP256k1 from ecdsa.curves import SECP256k1
@ -27,7 +28,7 @@ from ecdsa.util import number_to_string, string_to_number
from trezorlib import messages, tools from trezorlib import messages, tools
def point_to_pubkey(point): def point_to_pubkey(point: Point) -> bytes:
order = SECP256k1.order order = SECP256k1.order
x_str = number_to_string(point.x(), order) x_str = number_to_string(point.x(), order)
y_str = number_to_string(point.y(), order) y_str = number_to_string(point.y(), order)
@ -35,14 +36,14 @@ def point_to_pubkey(point):
return struct.pack("B", (vk[63] & 1) + 2) + vk[0:32] # To compressed key return struct.pack("B", (vk[63] & 1) + 2) + vk[0:32] # To compressed key
def sec_to_public_pair(pubkey): def sec_to_public_pair(pubkey: bytes) -> Tuple[int, Any]:
"""Convert a public key in sec binary format to a public pair.""" """Convert a public key in sec binary format to a public pair."""
x = string_to_number(pubkey[1:33]) x = string_to_number(pubkey[1:33])
sec0 = pubkey[:1] sec0 = pubkey[:1]
if sec0 not in (b"\2", b"\3"): if sec0 not in (b"\2", b"\3"):
raise ValueError("Compressed pubkey expected") raise ValueError("Compressed pubkey expected")
def public_pair_for_x(generator, x, is_even): def public_pair_for_x(generator, x: int, is_even: bool) -> Tuple[int, Any]:
curve = generator.curve() curve = generator.curve()
p = curve.p() p = curve.p()
alpha = (pow(x, 3, p) + curve.a() * x + curve.b()) % p alpha = (pow(x, 3, p) + curve.a() * x + curve.b()) % p
@ -56,15 +57,15 @@ def sec_to_public_pair(pubkey):
) )
def fingerprint(pubkey): def fingerprint(pubkey: bytes) -> int:
return string_to_number(tools.hash_160(pubkey)[:4]) return string_to_number(tools.hash_160(pubkey)[:4])
def get_address(public_node, address_type): def get_address(public_node: messages.HDNodeType, address_type: int) -> str:
return tools.public_key_to_bc_address(public_node.public_key, address_type) return tools.public_key_to_bc_address(public_node.public_key, address_type)
def public_ckd(public_node, n): def public_ckd(public_node: messages.HDNodeType, n: List[int]):
if not isinstance(n, list): if not isinstance(n, list):
raise ValueError("Parameter must be a list") raise ValueError("Parameter must be a list")
@ -76,7 +77,7 @@ def public_ckd(public_node, n):
return node return node
def get_subnode(node, i): def get_subnode(node: messages.HDNodeType, i: int) -> messages.HDNodeType:
# Public Child key derivation (CKD) algorithm of BIP32 # Public Child key derivation (CKD) algorithm of BIP32
i_as_bytes = struct.pack(">L", i) i_as_bytes = struct.pack(">L", i)
@ -108,7 +109,7 @@ def get_subnode(node, i):
) )
def serialize(node, version=0x0488B21E): def serialize(node: messages.HDNodeType, version: int = 0x0488B21E) -> str:
s = b"" s = b""
s += struct.pack(">I", version) s += struct.pack(">I", version)
s += struct.pack(">B", node.depth) s += struct.pack(">B", node.depth)
@ -123,7 +124,7 @@ def serialize(node, version=0x0488B21E):
return tools.b58encode(s) return tools.b58encode(s)
def deserialize(xpub): def deserialize(xpub: str) -> messages.HDNodeType:
data = tools.b58decode(xpub, None) data = tools.b58decode(xpub, None)
if tools.btc_hash(data[:-4])[:4] != data[-4:]: if tools.btc_hash(data[:-4])[:4] != data[-4:]:

View File

@ -1,8 +1,10 @@
from typing import Iterator, Tuple
DISPLAY_WIDTH = 240 DISPLAY_WIDTH = 240
DISPLAY_HEIGHT = 240 DISPLAY_HEIGHT = 240
def grid(dim, grid_cells, cell): def grid(dim: int, grid_cells: int, cell: int) -> int:
step = dim // grid_cells step = dim // grid_cells
ofs = step // 2 ofs = step // 2
return cell * step + ofs return cell * step + ofs
@ -34,15 +36,15 @@ RESET_WORD_CHECK = [
BUTTON_LETTERS = ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz") BUTTON_LETTERS = ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz")
def grid35(x, y): def grid35(x: int, y: int) -> Tuple[int, int]:
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 5, y) return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 5, y)
def grid34(x, y): def grid34(x: int, y: int) -> Tuple[int, int]:
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 4, y) return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 4, y)
def type_word(word): def type_word(word: str) -> Iterator[Tuple[int, int]]:
for l in word: for l in word:
idx = next(i for i, letters in enumerate(BUTTON_LETTERS) if l in letters) idx = next(i for i, letters in enumerate(BUTTON_LETTERS) if l in letters)
grid_x = idx % 3 grid_x = idx % 3

View File

@ -18,12 +18,18 @@ import json
import os import os
from decimal import Decimal from decimal import Decimal
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Generator, List, Optional
import pytest import pytest
import requests import requests
from trezorlib import btc, tools from trezorlib import btc, tools
from trezorlib.messages import ButtonRequestType as B from trezorlib.messages import ButtonRequestType
if TYPE_CHECKING:
from trezorlib.debuglink import DebugLink, TrezorClientDebugLink as Client
from trezorlib.messages import ButtonRequest
from _pytest.mark.structures import MarkDecorator
# fmt: off # fmt: off
# 1 2 3 4 5 6 7 8 9 10 11 12 # 1 2 3 4 5 6 7 8 9 10 11 12
@ -56,7 +62,7 @@ COMMON_FIXTURES_DIR = (
) )
def parametrize_using_common_fixtures(*paths): def parametrize_using_common_fixtures(*paths: str) -> "MarkDecorator":
fixtures = [] fixtures = []
for path in paths: for path in paths:
fixtures.append(json.loads((COMMON_FIXTURES_DIR / path).read_text())) fixtures.append(json.loads((COMMON_FIXTURES_DIR / path).read_text()))
@ -85,7 +91,9 @@ def parametrize_using_common_fixtures(*paths):
return pytest.mark.parametrize("parameters, result", tests) return pytest.mark.parametrize("parameters, result", tests)
def generate_entropy(strength, internal_entropy, external_entropy): def generate_entropy(
strength: int, internal_entropy: bytes, external_entropy: bytes
) -> bytes:
""" """
strength - length of produced seed. One of 128, 192, 256 strength - length of produced seed. One of 128, 192, 256
random - binary stream of random data from external HRNG random - binary stream of random data from external HRNG
@ -116,7 +124,12 @@ def generate_entropy(strength, internal_entropy, external_entropy):
return entropy_stripped return entropy_stripped
def recovery_enter_shares(debug, shares, groups=False, click_info=False): def recovery_enter_shares(
debug: "DebugLink",
shares: List[str],
groups: bool = False,
click_info: bool = False,
) -> Generator[None, "ButtonRequest", None]:
"""Perform the recovery flow for a set of Shamir shares. """Perform the recovery flow for a set of Shamir shares.
For use in an input flow function. For use in an input flow function.
@ -134,7 +147,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
debug.press_yes() debug.press_yes()
# Input word number # Input word number
br = yield br = yield
assert br.code == B.MnemonicWordCount assert br.code == ButtonRequestType.MnemonicWordCount
debug.input(str(word_count)) debug.input(str(word_count))
# Homescreen - proceed to share entry # Homescreen - proceed to share entry
yield yield
@ -142,7 +155,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
# Enter shares # Enter shares
for share in shares: for share in shares:
br = yield br = yield
assert br.code == B.MnemonicInput assert br.code == ButtonRequestType.MnemonicInput
# Enter mnemonic words # Enter mnemonic words
for word in share.split(" "): for word in share.split(" "):
debug.input(word) debug.input(word)
@ -167,7 +180,9 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
debug.press_yes() debug.press_yes()
def click_through(debug, screens, code=None): def click_through(
debug: "DebugLink", screens: int, code: ButtonRequestType = None
) -> Generator[None, "ButtonRequest", None]:
"""Click through N dialog screens. """Click through N dialog screens.
For use in an input flow function. For use in an input flow function.
@ -178,7 +193,7 @@ def click_through(debug, screens, code=None):
# 2. Backup your seed # 2. Backup your seed
# 3. Confirm warning # 3. Confirm warning
# 4. Shares info # 4. Shares info
yield from click_through(client.debug, screens=4, code=B.ResetDevice) yield from click_through(client.debug, screens=4, code=ButtonRequestType.ResetDevice)
""" """
for _ in range(screens): for _ in range(screens):
received = yield received = yield
@ -187,7 +202,9 @@ def click_through(debug, screens, code=None):
debug.press_yes() debug.press_yes()
def read_and_confirm_mnemonic(debug, choose_wrong=False): def read_and_confirm_mnemonic(
debug: "DebugLink", choose_wrong: bool = False
) -> Generator[None, "ButtonRequest", Optional[str]]:
"""Read a given number of mnemonic words from Trezor T screen and correctly """Read a given number of mnemonic words from Trezor T screen and correctly
answer confirmation questions. Return the full mnemonic. answer confirmation questions. Return the full mnemonic.
@ -201,6 +218,7 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False):
""" """
mnemonic = [] mnemonic = []
br = yield br = yield
assert br.pages is not None
for _ in range(br.pages - 1): for _ in range(br.pages - 1):
mnemonic.extend(debug.read_reset_word().split()) mnemonic.extend(debug.read_reset_word().split())
debug.swipe_up(wait=True) debug.swipe_up(wait=True)
@ -221,13 +239,15 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False):
return " ".join(mnemonic) return " ".join(mnemonic)
def get_test_address(client): def get_test_address(client: "Client") -> str:
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase """Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
protected call, or to identify the root secret (seed+passphrase)""" protected call, or to identify the root secret (seed+passphrase)"""
return btc.get_address(client, "Testnet", TEST_ADDRESS_N) return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
def assert_tx_matches(serialized_tx: bytes, hash_link: str, tx_hex: str = None) -> None: def assert_tx_matches(
serialized_tx: bytes, hash_link: str, tx_hex: str = None
) -> None:
"""Verifies if a transaction is correctly formed.""" """Verifies if a transaction is correctly formed."""
hash_str = hash_link.split("/")[-1] hash_str = hash_link.split("/")[-1]
assert tools.tx_hash(serialized_tx).hex() == hash_str assert tools.tx_hash(serialized_tx).hex() == hash_str

View File

@ -15,6 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os import os
from typing import TYPE_CHECKING
import pytest import pytest
@ -27,12 +28,17 @@ from . import ui_tests
from .device_handler import BackgroundDeviceHandler from .device_handler import BackgroundDeviceHandler
from .ui_tests.reporting import testreport from .ui_tests.reporting import testreport
if TYPE_CHECKING:
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.terminal import TerminalReporter
# So that we see details of failed asserts from this module # So that we see details of failed asserts from this module
pytest.register_assert_rewrite("tests.common") pytest.register_assert_rewrite("tests.common")
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def _raw_client(request): def _raw_client(request: pytest.FixtureRequest) -> TrezorClientDebugLink:
path = os.environ.get("TREZOR_PATH") path = os.environ.get("TREZOR_PATH")
interact = int(os.environ.get("INTERACT", 0)) interact = int(os.environ.get("INTERACT", 0))
if path: if path:
@ -56,7 +62,9 @@ def _raw_client(request):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def client(request, _raw_client): def client(
request: pytest.FixtureRequest, _raw_client: TrezorClientDebugLink
) -> TrezorClientDebugLink:
"""Client fixture. """Client fixture.
Every test function that requires a client instance will get it from here. Every test function that requires a client instance will get it from here.
@ -156,20 +164,20 @@ def client(request, _raw_client):
_raw_client.close() _raw_client.close()
def pytest_sessionstart(session): def pytest_sessionstart(session: pytest.Session) -> None:
ui_tests.read_fixtures() ui_tests.read_fixtures()
if session.config.getoption("ui") == "test": if session.config.getoption("ui") == "test":
testreport.clear_dir() testreport.clear_dir()
def _should_write_ui_report(exitstatus): def _should_write_ui_report(exitstatus: pytest.ExitCode) -> bool:
# generate UI report and check missing only if pytest is exitting cleanly # generate UI report and check missing only if pytest is exitting cleanly
# I.e., the test suite passed or failed (as opposed to ctrl+c break, internal error, # I.e., the test suite passed or failed (as opposed to ctrl+c break, internal error,
# etc.) # etc.)
return exitstatus in (pytest.ExitCode.OK, pytest.ExitCode.TESTS_FAILED) return exitstatus in (pytest.ExitCode.OK, pytest.ExitCode.TESTS_FAILED)
def pytest_sessionfinish(session, exitstatus): def pytest_sessionfinish(session: pytest.Session, exitstatus: pytest.ExitCode) -> None:
if not _should_write_ui_report(exitstatus): if not _should_write_ui_report(exitstatus):
return return
@ -183,7 +191,9 @@ def pytest_sessionfinish(session, exitstatus):
ui_tests.write_fixtures(missing) ui_tests.write_fixtures(missing)
def pytest_terminal_summary(terminalreporter, exitstatus, config): def pytest_terminal_summary(
terminalreporter: "TerminalReporter", exitstatus: pytest.ExitCode, config: "Config"
) -> None:
println = terminalreporter.write_line println = terminalreporter.write_line
println("") println("")
@ -213,7 +223,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
println("") println("")
def pytest_addoption(parser): def pytest_addoption(parser: "Parser") -> None:
parser.addoption( parser.addoption(
"--ui", "--ui",
action="store", action="store",
@ -229,7 +239,7 @@ def pytest_addoption(parser):
) )
def pytest_configure(config): def pytest_configure(config: "Config") -> None:
"""Called at testsuite setup time. """Called at testsuite setup time.
Registers known markers, enables verbose output if requested. Registers known markers, enables verbose output if requested.
@ -253,7 +263,7 @@ def pytest_configure(config):
log.enable_debug_output() log.enable_debug_output()
def pytest_runtest_setup(item): def pytest_runtest_setup(item: pytest.Item) -> None:
"""Called for each test item (class, individual tests). """Called for each test item (class, individual tests).
Ensures that altcoin tests are skipped, and that no test is skipped on Ensures that altcoin tests are skipped, and that no test is skipped on
@ -267,7 +277,7 @@ def pytest_runtest_setup(item):
pytest.skip("Skipping altcoin test") pytest.skip("Skipping altcoin test")
def pytest_runtest_teardown(item): def pytest_runtest_teardown(item: pytest.Item) -> None:
"""Called after a test item finishes. """Called after a test item finishes.
Dumps the current UI test report HTML. Dumps the current UI test report HTML.
@ -277,7 +287,7 @@ def pytest_runtest_teardown(item):
@pytest.hookimpl(tryfirst=True, hookwrapper=True) @pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call): def pytest_runtest_makereport(item: pytest.Item, call) -> None:
# Make test results available in fixtures. # Make test results available in fixtures.
# See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures # See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
# The device_handler fixture uses this as 'request.node.rep_call.passed' attribute, # The device_handler fixture uses this as 'request.node.rep_call.passed' attribute,
@ -288,7 +298,9 @@ def pytest_runtest_makereport(item, call):
@pytest.fixture @pytest.fixture
def device_handler(client, request): def device_handler(
client: TrezorClientDebugLink, request: pytest.FixtureRequest
) -> None:
device_handler = BackgroundDeviceHandler(client) device_handler = BackgroundDeviceHandler(client)
yield device_handler yield device_handler
@ -299,5 +311,5 @@ def device_handler(client, request):
# if test finished, make sure all background tasks are done # if test finished, make sure all background tasks are done
finalized_ok = device_handler.check_finalize() finalized_ok = device_handler.check_finalize()
if request.node.rep_call.passed and not finalized_ok: if request.node.rep_call.passed and not finalized_ok: # type: ignore [rep_call must exist]
raise RuntimeError("Test did not check result of background task") raise RuntimeError("Test did not check result of background task")

View File

@ -1,8 +1,15 @@
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING
from trezorlib.client import PASSPHRASE_ON_DEVICE from trezorlib.client import PASSPHRASE_ON_DEVICE
from trezorlib.transport import udp from trezorlib.transport import udp
if TYPE_CHECKING:
from trezorlib.messages import Features
from trezorlib.debuglink import TrezorClientDebugLink, DebugLink
from trezorlib._internal.emulator import Emulator
udp.SOCKET_TIMEOUT = 0.1 udp.SOCKET_TIMEOUT = 0.1
@ -26,13 +33,13 @@ class NullUI:
class BackgroundDeviceHandler: class BackgroundDeviceHandler:
_pool = ThreadPoolExecutor() _pool = ThreadPoolExecutor()
def __init__(self, client): def __init__(self, client: "TrezorClientDebugLink") -> None:
self._configure_client(client) self._configure_client(client)
self.task = None self.task = None
def _configure_client(self, client): def _configure_client(self, client: "TrezorClientDebugLink") -> None:
self.client = client self.client = client
self.client.ui = NullUI self.client.ui = NullUI # type: ignore [NullUI is OK UI]
self.client.watch_layout(True) self.client.watch_layout(True)
def run(self, function, *args, **kwargs): def run(self, function, *args, **kwargs):
@ -40,7 +47,7 @@ class BackgroundDeviceHandler:
raise RuntimeError("Wait for previous task first") raise RuntimeError("Wait for previous task first")
self.task = self._pool.submit(function, self.client, *args, **kwargs) self.task = self._pool.submit(function, self.client, *args, **kwargs)
def kill_task(self): def kill_task(self) -> None:
if self.task is not None: if self.task is not None:
# Force close the client, which should raise an exception in a client # Force close the client, which should raise an exception in a client
# waiting on IO. Does not work over Bridge, because bridge doesn't have # waiting on IO. Does not work over Bridge, because bridge doesn't have
@ -53,11 +60,11 @@ class BackgroundDeviceHandler:
pass pass
self.task = None self.task = None
def restart(self, emulator): def restart(self, emulator: "Emulator"):
# TODO handle actual restart as well # TODO handle actual restart as well
self.kill_task() self.kill_task()
emulator.restart() emulator.restart()
self._configure_client(emulator.client) self._configure_client(emulator.client) # type: ignore [client cannot be None]
def result(self): def result(self):
if self.task is None: if self.task is None:
@ -67,25 +74,25 @@ class BackgroundDeviceHandler:
finally: finally:
self.task = None self.task = None
def features(self): def features(self) -> "Features":
if self.task is not None: if self.task is not None:
raise RuntimeError("Cannot query features while task is running") raise RuntimeError("Cannot query features while task is running")
self.client.init_device() self.client.init_device()
return self.client.features return self.client.features
def debuglink(self): def debuglink(self) -> "DebugLink":
return self.client.debug return self.client.debug
def check_finalize(self): def check_finalize(self) -> bool:
if self.task is not None: if self.task is not None:
self.kill_task() self.kill_task()
return False return False
return True return True
def __enter__(self): def __enter__(self) -> "BackgroundDeviceHandler":
return self return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback) -> None:
finalized_ok = self.check_finalize() finalized_ok = self.check_finalize()
if exc_type is None and not finalized_ok: if exc_type is None and not finalized_ok:
raise RuntimeError("Exit while task is unfinished") raise RuntimeError("Exit while task is unfinished")

View File

@ -17,8 +17,9 @@
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple
from trezorlib._internal.emulator import CoreEmulator, LegacyEmulator from trezorlib._internal.emulator import CoreEmulator, Emulator, LegacyEmulator
ROOT = Path(__file__).resolve().parent.parent ROOT = Path(__file__).resolve().parent.parent
BINDIR = ROOT / "tests" / "emulators" BINDIR = ROOT / "tests" / "emulators"
@ -33,18 +34,18 @@ CORE_SRC_DIR = ROOT / "core" / "src"
ENV = {"SDL_VIDEODRIVER": "dummy"} ENV = {"SDL_VIDEODRIVER": "dummy"}
def check_version(tag, version_tuple): def check_version(tag: str, version_tuple: Tuple[int, int, int]) -> None:
if tag is not None and tag.startswith("v") and len(tag.split(".")) == 3: if tag is not None and tag.startswith("v") and len(tag.split(".")) == 3:
version = ".".join(str(i) for i in version_tuple) version = ".".join(str(i) for i in version_tuple)
if tag[1:] != version: if tag[1:] != version:
raise RuntimeError(f"Version mismatch: tag {tag} reports version {version}") raise RuntimeError(f"Version mismatch: tag {tag} reports version {version}")
def filename_from_tag(gen, tag): def filename_from_tag(gen: str, tag: str) -> Path:
return BINDIR / f"trezor-emu-{gen}-{tag}" return BINDIR / f"trezor-emu-{gen}-{tag}"
def get_tags(): def get_tags() -> Dict[str, List[str]]:
files = list(BINDIR.iterdir()) files = list(BINDIR.iterdir())
if not files: if not files:
raise ValueError( raise ValueError(
@ -66,7 +67,7 @@ ALL_TAGS = get_tags()
class EmulatorWrapper: class EmulatorWrapper:
def __init__(self, gen, tag=None, storage=None): def __init__(self, gen: str, tag: str = None, storage: bytes = None) -> None:
if tag is not None: if tag is not None:
executable = filename_from_tag(gen, tag) executable = filename_from_tag(gen, tag)
else: else:
@ -96,11 +97,15 @@ class EmulatorWrapper:
workdir=workdir, workdir=workdir,
headless=True, headless=True,
) )
else:
raise ValueError(
f"Unrecognized gen - {gen} - only 'core' and 'legacy' supported"
)
def __enter__(self): def __enter__(self) -> Emulator:
self.emulator.start() self.emulator.start()
return self.emulator return self.emulator
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback) -> None:
self.emulator.stop() self.emulator.stop()
self.profile_dir.cleanup() self.profile_dir.cleanup()

View File

@ -5,9 +5,9 @@ import multiprocessing
import os import os
import posixpath import posixpath
import time import time
import urllib
import webbrowser import webbrowser
from pathlib import Path from pathlib import Path
from urllib.parse import unquote
import click import click
@ -16,16 +16,16 @@ TEST_RESULT_PATH = ROOT / "tests" / "ui_tests" / "reporting" / "reports" / "test
class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler): class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
def end_headers(self): def end_headers(self) -> None:
self.send_header("Cache-Control", "no-cache, no-store, must-revalidate") self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
self.send_header("Pragma", "no-cache") self.send_header("Pragma", "no-cache")
self.send_header("Expires", "0") self.send_header("Expires", "0")
return super().end_headers() return super().end_headers()
def log_message(self, format, *args): def log_message(self, format, *args) -> None:
pass pass
def translate_path(self, path): def translate_path(self, path: str) -> str:
# XXX # XXX
# Copy-pasted from Python 3.8 BaseHTTPRequestHandler so that we can inject # Copy-pasted from Python 3.8 BaseHTTPRequestHandler so that we can inject
# the `directory` parameter. # the `directory` parameter.
@ -36,9 +36,9 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
# Don't forget explicit trailing slash when normalizing. Issue17324 # Don't forget explicit trailing slash when normalizing. Issue17324
trailing_slash = path.rstrip().endswith("/") trailing_slash = path.rstrip().endswith("/")
try: try:
path = urllib.parse.unquote(path, errors="surrogatepass") path = unquote(path, errors="surrogatepass")
except UnicodeDecodeError: except UnicodeDecodeError:
path = urllib.parse.unquote(path) path = unquote(path)
path = posixpath.normpath(path) path = posixpath.normpath(path)
words = path.split("/") words = path.split("/")
words = filter(None, words) words = filter(None, words)
@ -53,13 +53,13 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
return path return path
def launch_http_server(port): def launch_http_server(port: int) -> None:
http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port) http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port) # type: ignore [test is defined]
@click.command() @click.command()
@click.option("-p", "--port", type=int, default=8000) @click.option("-p", "--port", type=int, default=8000)
def main(port): def main(port: int):
httpd = multiprocessing.Process(target=launch_http_server, args=(port,)) httpd = multiprocessing.Process(target=launch_http_server, args=(port,))
httpd.start() httpd.start()
time.sleep(0.5) time.sleep(0.5)

View File

@ -122,7 +122,7 @@ def cli(tx, coin_name):
tx_dict = protobuf.to_dict(tx_proto) tx_dict = protobuf.to_dict(tx_proto)
tx_json = json.dumps(tx_dict, sort_keys=True, indent=2) + "\n" tx_json = json.dumps(tx_dict, sort_keys=True, indent=2) + "\n"
except Exception as e: except Exception as e:
raise click.ClickException(e) from e raise click.ClickException(str(e)) from e
cache_dir = CACHE_PATH / coin_name cache_dir = CACHE_PATH / coin_name
if not cache_dir.exists(): if not cache_dir.exists():

View File

@ -4,23 +4,26 @@ import re
import shutil import shutil
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Dict, Generator, Set
import pytest import pytest
from _pytest.outcomes import Failed from _pytest.outcomes import Failed
from PIL import Image from PIL import Image
from trezorlib.debuglink import TrezorClientDebugLink as Client
from .reporting import testreport from .reporting import testreport
UI_TESTS_DIR = Path(__file__).resolve().parent UI_TESTS_DIR = Path(__file__).resolve().parent
SCREENS_DIR = UI_TESTS_DIR / "screens" SCREENS_DIR = UI_TESTS_DIR / "screens"
HASH_FILE = UI_TESTS_DIR / "fixtures.json" HASH_FILE = UI_TESTS_DIR / "fixtures.json"
SUGGESTION_FILE = UI_TESTS_DIR / "fixtures.suggestion.json" SUGGESTION_FILE = UI_TESTS_DIR / "fixtures.suggestion.json"
FILE_HASHES = {} FILE_HASHES: Dict[str, str] = {}
ACTUAL_HASHES = {} ACTUAL_HASHES: Dict[str, str] = {}
PROCESSED = set() PROCESSED: Set[str] = set()
def get_test_name(node_id): def get_test_name(node_id: str) -> str:
# Test item name is usually function name, but when parametrization is used, # Test item name is usually function name, but when parametrization is used,
# parameters are also part of the name. Some functions have very long parameter # parameters are also part of the name. Some functions have very long parameter
# names (tx hashes etc) that run out of maximum allowable filename length, so # names (tx hashes etc) that run out of maximum allowable filename length, so
@ -34,14 +37,14 @@ def get_test_name(node_id):
return new_name[:91] + "-" + hashlib.sha256(new_name.encode()).hexdigest()[:8] return new_name[:91] + "-" + hashlib.sha256(new_name.encode()).hexdigest()[:8]
def _process_recorded(screen_path, test_name): def _process_recorded(screen_path: Path, test_name: str) -> None:
# calculate hash # calculate hash
FILE_HASHES[test_name] = _hash_files(screen_path) FILE_HASHES[test_name] = _hash_files(screen_path)
_rename_records(screen_path) _rename_records(screen_path)
PROCESSED.add(test_name) PROCESSED.add(test_name)
def _rename_records(screen_path): def _rename_records(screen_path: Path) -> None:
# rename screenshots # rename screenshots
for index, record in enumerate(sorted(screen_path.iterdir())): for index, record in enumerate(sorted(screen_path.iterdir())):
record.replace(screen_path / f"{index:08}.png") record.replace(screen_path / f"{index:08}.png")
@ -65,7 +68,7 @@ def _get_bytes_from_png(png_file: str) -> bytes:
return Image.open(png_file).tobytes() return Image.open(png_file).tobytes()
def _process_tested(fixture_test_path, test_name): def _process_tested(fixture_test_path: Path, test_name: str) -> None:
PROCESSED.add(test_name) PROCESSED.add(test_name)
actual_path = fixture_test_path / "actual" actual_path = fixture_test_path / "actual"
@ -79,6 +82,7 @@ def _process_tested(fixture_test_path, test_name):
pytest.fail(f"Hash of {test_name} not found in fixtures.json") pytest.fail(f"Hash of {test_name} not found in fixtures.json")
if actual_hash != expected_hash: if actual_hash != expected_hash:
assert expected_hash is not None
file_path = testreport.failed( file_path = testreport.failed(
fixture_test_path, test_name, actual_hash, expected_hash fixture_test_path, test_name, actual_hash, expected_hash
) )
@ -94,7 +98,9 @@ def _process_tested(fixture_test_path, test_name):
@contextmanager @contextmanager
def screen_recording(client, request): def screen_recording(
client: Client, request: pytest.FixtureRequest
) -> Generator[None, None, None]:
test_ui = request.config.getoption("ui") test_ui = request.config.getoption("ui")
test_name = get_test_name(request.node.nodeid) test_name = get_test_name(request.node.nodeid)
screens_test_path = SCREENS_DIR / test_name screens_test_path = SCREENS_DIR / test_name
@ -126,26 +132,26 @@ def screen_recording(client, request):
_process_tested(screens_test_path, test_name) _process_tested(screens_test_path, test_name)
def list_missing(): def list_missing() -> Set[str]:
return set(FILE_HASHES.keys()) - PROCESSED return set(FILE_HASHES.keys()) - PROCESSED
def read_fixtures(): def read_fixtures() -> None:
if not HASH_FILE.exists(): if not HASH_FILE.exists():
raise ValueError("File fixtures.json not found.") raise ValueError("File fixtures.json not found.")
global FILE_HASHES global FILE_HASHES
FILE_HASHES = json.loads(HASH_FILE.read_text()) FILE_HASHES = json.loads(HASH_FILE.read_text())
def write_fixtures(remove_missing: bool): def write_fixtures(remove_missing: bool) -> None:
HASH_FILE.write_text(_get_fixtures_content(FILE_HASHES, remove_missing)) HASH_FILE.write_text(_get_fixtures_content(FILE_HASHES, remove_missing))
def write_fixtures_suggestion(remove_missing: bool): def write_fixtures_suggestion(remove_missing: bool) -> None:
SUGGESTION_FILE.write_text(_get_fixtures_content(ACTUAL_HASHES, remove_missing)) SUGGESTION_FILE.write_text(_get_fixtures_content(ACTUAL_HASHES, remove_missing))
def _get_fixtures_content(fixtures: dict, remove_missing: bool): def _get_fixtures_content(fixtures: Dict[str, str], remove_missing: bool) -> str:
if remove_missing: if remove_missing:
fixtures = {i: fixtures[i] for i in PROCESSED} fixtures = {i: fixtures[i] for i in PROCESSED}
else: else:
@ -154,7 +160,7 @@ def _get_fixtures_content(fixtures: dict, remove_missing: bool):
return json.dumps(fixtures, indent="", sort_keys=True) + "\n" return json.dumps(fixtures, indent="", sort_keys=True) + "\n"
def main(): def main() -> None:
read_fixtures() read_fixtures()
for record in SCREENS_DIR.iterdir(): for record in SCREENS_DIR.iterdir():
if not (record / "actual").exists(): if not (record / "actual").exists():

View File

@ -14,7 +14,7 @@ FIXTURES_CURRENT = Path(__file__).resolve().parent.parent / "fixtures.json"
_dns_failed = False _dns_failed = False
def fetch_recorded(hash, path): def fetch_recorded(hash: str, path: Path) -> None:
global _dns_failed global _dns_failed
if _dns_failed: if _dns_failed:

View File

@ -1,11 +1,15 @@
import base64 import base64
import filecmp import filecmp
from itertools import zip_longest from itertools import zip_longest
from pathlib import Path
from typing import Dict, List
from dominate.tags import a, i, img, table, td, th, tr from dominate.tags import a, i, img, table, td, th, tr
def report_links(tests, reports_path, actual_hashes=None): def report_links(
tests: List[Path], reports_path: Path, actual_hashes: Dict[str, str] = None
) -> None:
if actual_hashes is None: if actual_hashes is None:
actual_hashes = {} actual_hashes = {}
@ -21,12 +25,12 @@ def report_links(tests, reports_path, actual_hashes=None):
td(a(test.name, href=path)) td(a(test.name, href=path))
def write(fixture_test_path, doc, filename): def write(fixture_test_path: Path, doc, filename: str) -> Path:
(fixture_test_path / filename).write_text(doc.render()) (fixture_test_path / filename).write_text(doc.render())
return fixture_test_path / filename return fixture_test_path / filename
def image(src): def image(src: Path) -> None:
with td(): with td():
if src: if src:
# open image file # open image file
@ -41,7 +45,7 @@ def image(src):
i("missing") i("missing")
def diff_table(left_screens, right_screens): def diff_table(left_screens: List[Path], right_screens: List[Path]) -> None:
for left, right in zip_longest(left_screens, right_screens): for left, right in zip_longest(left_screens, right_screens):
if left and right and filecmp.cmp(right, left): if left and right and filecmp.cmp(right, left):
background = "white" background = "white"

View File

@ -2,6 +2,7 @@ import shutil
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple
import dominate import dominate
from dominate.tags import br, h1, h2, hr, i, p, table, td, th, tr from dominate.tags import br, h1, h2, hr, i, p, table, td, th, tr
@ -14,7 +15,7 @@ REPORTS_PATH = Path(__file__).resolve().parent / "reports" / "master_diff"
RECORDED_SCREENS_PATH = Path(__file__).resolve().parent.parent / "screens" RECORDED_SCREENS_PATH = Path(__file__).resolve().parent.parent / "screens"
def get_diff(): def get_diff() -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
master = download.fetch_fixtures_master() master = download.fetch_fixtures_master()
current = download.fetch_fixtures_current() current = download.fetch_fixtures_current()
@ -34,7 +35,7 @@ def get_diff():
return removed, added, diff return removed, added, diff
def removed(screens_path, test_name): def removed(screens_path: Path, test_name: str) -> Path:
doc = dominate.document(title=test_name) doc = dominate.document(title=test_name)
screens = sorted(screens_path.iterdir()) screens = sorted(screens_path.iterdir())
@ -57,7 +58,7 @@ def removed(screens_path, test_name):
return html.write(REPORTS_PATH / "removed", doc, test_name + ".html") return html.write(REPORTS_PATH / "removed", doc, test_name + ".html")
def added(screens_path, test_name): def added(screens_path: Path, test_name: str) -> Path:
doc = dominate.document(title=test_name) doc = dominate.document(title=test_name)
screens = sorted(screens_path.iterdir()) screens = sorted(screens_path.iterdir())
@ -81,8 +82,12 @@ def added(screens_path, test_name):
def diff( def diff(
master_screens_path, current_screens_path, test_name, master_hash, current_hash master_screens_path: Path,
): current_screens_path: Path,
test_name: str,
master_hash: str,
current_hash: str,
) -> Path:
doc = dominate.document(title=test_name) doc = dominate.document(title=test_name)
master_screens = sorted(master_screens_path.iterdir()) master_screens = sorted(master_screens_path.iterdir())
current_screens = sorted(current_screens_path.iterdir()) current_screens = sorted(current_screens_path.iterdir())
@ -109,7 +114,7 @@ def diff(
return html.write(REPORTS_PATH / "diff", doc, test_name + ".html") return html.write(REPORTS_PATH / "diff", doc, test_name + ".html")
def index(): def index() -> Path:
removed = list((REPORTS_PATH / "removed").iterdir()) removed = list((REPORTS_PATH / "removed").iterdir())
added = list((REPORTS_PATH / "added").iterdir()) added = list((REPORTS_PATH / "added").iterdir())
diff = list((REPORTS_PATH / "diff").iterdir()) diff = list((REPORTS_PATH / "diff").iterdir())
@ -140,7 +145,7 @@ def index():
return html.write(REPORTS_PATH, doc, "index.html") return html.write(REPORTS_PATH, doc, "index.html")
def create_dirs(): def create_dirs() -> None:
# delete the reports dir to clear previous entries and create folders # delete the reports dir to clear previous entries and create folders
shutil.rmtree(REPORTS_PATH, ignore_errors=True) shutil.rmtree(REPORTS_PATH, ignore_errors=True)
REPORTS_PATH.mkdir() REPORTS_PATH.mkdir()
@ -149,7 +154,7 @@ def create_dirs():
(REPORTS_PATH / "diff").mkdir() (REPORTS_PATH / "diff").mkdir()
def create_reports(): def create_reports() -> None:
removed_tests, added_tests, diff_tests = get_diff() removed_tests, added_tests, diff_tests = get_diff()
@contextmanager @contextmanager

View File

@ -2,6 +2,7 @@ import shutil
from datetime import datetime from datetime import datetime
from distutils.dir_util import copy_tree from distutils.dir_util import copy_tree
from pathlib import Path from pathlib import Path
from typing import Dict
import dominate import dominate
import dominate.tags as t import dominate.tags as t
@ -16,10 +17,12 @@ REPORTS_PATH = HERE / "reports" / "test"
STYLE = (HERE / "testreport.css").read_text() STYLE = (HERE / "testreport.css").read_text()
SCRIPT = (HERE / "testreport.js").read_text() SCRIPT = (HERE / "testreport.js").read_text()
ACTUAL_HASHES = {} ACTUAL_HASHES: Dict[str, str] = {}
def document(title, actual_hash=None, index=False): def document(
title: str, actual_hash: str = None, index: bool = False
) -> dominate.document:
doc = dominate.document(title=title) doc = dominate.document(title=title)
style = t.style() style = t.style()
style.add_raw_string(STYLE) style.add_raw_string(STYLE)
@ -36,7 +39,7 @@ def document(title, actual_hash=None, index=False):
return doc return doc
def _header(test_name, expected_hash, actual_hash): def _header(test_name: str, expected_hash: str, actual_hash: str) -> None:
h1(test_name) h1(test_name)
with div(): with div():
if actual_hash == expected_hash: if actual_hash == expected_hash:
@ -54,7 +57,7 @@ def _header(test_name, expected_hash, actual_hash):
hr() hr()
def clear_dir(): def clear_dir() -> None:
# delete and create the reports dir to clear previous entries # delete and create the reports dir to clear previous entries
shutil.rmtree(REPORTS_PATH, ignore_errors=True) shutil.rmtree(REPORTS_PATH, ignore_errors=True)
REPORTS_PATH.mkdir() REPORTS_PATH.mkdir()
@ -62,7 +65,7 @@ def clear_dir():
(REPORTS_PATH / "passed").mkdir() (REPORTS_PATH / "passed").mkdir()
def index(): def index() -> Path:
passed_tests = list((REPORTS_PATH / "passed").iterdir()) passed_tests = list((REPORTS_PATH / "passed").iterdir())
failed_tests = list((REPORTS_PATH / "failed").iterdir()) failed_tests = list((REPORTS_PATH / "failed").iterdir())
@ -104,7 +107,9 @@ def index():
return html.write(REPORTS_PATH, doc, "index.html") return html.write(REPORTS_PATH, doc, "index.html")
def failed(fixture_test_path, test_name, actual_hash, expected_hash): def failed(
fixture_test_path: Path, test_name: str, actual_hash: str, expected_hash: str
) -> Path:
ACTUAL_HASHES[test_name] = actual_hash ACTUAL_HASHES[test_name] = actual_hash
doc = document(title=test_name, actual_hash=actual_hash) doc = document(title=test_name, actual_hash=actual_hash)
@ -147,7 +152,7 @@ def failed(fixture_test_path, test_name, actual_hash, expected_hash):
return html.write(REPORTS_PATH / "failed", doc, test_name + ".html") return html.write(REPORTS_PATH / "failed", doc, test_name + ".html")
def passed(fixture_test_path, test_name, actual_hash): def passed(fixture_test_path: Path, test_name: str, actual_hash: str) -> Path:
copy_tree(str(fixture_test_path / "actual"), str(fixture_test_path / "recorded")) copy_tree(str(fixture_test_path / "actual"), str(fixture_test_path / "recorded"))
doc = document(title=test_name) doc = document(title=test_name)

View File

@ -15,8 +15,10 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os import os
from typing import List, Tuple
import pytest import pytest
from _pytest.mark.structures import MarkDecorator
from ..emulators import ALL_TAGS, LOCAL_BUILD_PATHS from ..emulators import ALL_TAGS, LOCAL_BUILD_PATHS
@ -44,7 +46,11 @@ core_only = pytest.mark.skipif(
) )
def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0, 0)): def for_all(
*args,
legacy_minimum_version: Tuple[int, int, int] = (1, 0, 0),
core_minimum_version: Tuple[int, int, int] = (2, 0, 0)
) -> "MarkDecorator":
"""Parametrizing decorator for test cases. """Parametrizing decorator for test cases.
Usage example: Usage example:
@ -98,7 +104,7 @@ def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0,
return pytest.mark.parametrize("gen, tag", all_params) return pytest.mark.parametrize("gen, tag", all_params)
def for_tags(*args): def for_tags(*args: Tuple[str, List[str]]) -> "MarkDecorator":
enabled_gens = SELECTED_GENS or ("core", "legacy") enabled_gens = SELECTED_GENS or ("core", "legacy")
return pytest.mark.parametrize( return pytest.mark.parametrize(
"gen, tags", [(gen, tags) for gen, tags in args if gen in enabled_gens] "gen, tags", [(gen, tags) for gen, tags in args if gen in enabled_gens]