mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-12 00:10:58 +00:00
chore(tests): add type hints to helper test functions
This commit is contained in:
parent
5d76144ef5
commit
c77e18d77c
@ -18,6 +18,7 @@ import hashlib
|
||||
import hmac
|
||||
import struct
|
||||
from copy import copy
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import ecdsa
|
||||
from ecdsa.curves import SECP256k1
|
||||
@ -27,7 +28,7 @@ from ecdsa.util import number_to_string, string_to_number
|
||||
from trezorlib import messages, tools
|
||||
|
||||
|
||||
def point_to_pubkey(point):
|
||||
def point_to_pubkey(point: Point) -> bytes:
|
||||
order = SECP256k1.order
|
||||
x_str = number_to_string(point.x(), order)
|
||||
y_str = number_to_string(point.y(), order)
|
||||
@ -35,14 +36,14 @@ def point_to_pubkey(point):
|
||||
return struct.pack("B", (vk[63] & 1) + 2) + vk[0:32] # To compressed key
|
||||
|
||||
|
||||
def sec_to_public_pair(pubkey):
|
||||
def sec_to_public_pair(pubkey: bytes) -> Tuple[int, Any]:
|
||||
"""Convert a public key in sec binary format to a public pair."""
|
||||
x = string_to_number(pubkey[1:33])
|
||||
sec0 = pubkey[:1]
|
||||
if sec0 not in (b"\2", b"\3"):
|
||||
raise ValueError("Compressed pubkey expected")
|
||||
|
||||
def public_pair_for_x(generator, x, is_even):
|
||||
def public_pair_for_x(generator, x: int, is_even: bool) -> Tuple[int, Any]:
|
||||
curve = generator.curve()
|
||||
p = curve.p()
|
||||
alpha = (pow(x, 3, p) + curve.a() * x + curve.b()) % p
|
||||
@ -56,15 +57,15 @@ def sec_to_public_pair(pubkey):
|
||||
)
|
||||
|
||||
|
||||
def fingerprint(pubkey):
|
||||
def fingerprint(pubkey: bytes) -> int:
|
||||
return string_to_number(tools.hash_160(pubkey)[:4])
|
||||
|
||||
|
||||
def get_address(public_node, address_type):
|
||||
def get_address(public_node: messages.HDNodeType, address_type: int) -> str:
|
||||
return tools.public_key_to_bc_address(public_node.public_key, address_type)
|
||||
|
||||
|
||||
def public_ckd(public_node, n):
|
||||
def public_ckd(public_node: messages.HDNodeType, n: List[int]):
|
||||
if not isinstance(n, list):
|
||||
raise ValueError("Parameter must be a list")
|
||||
|
||||
@ -76,7 +77,7 @@ def public_ckd(public_node, n):
|
||||
return node
|
||||
|
||||
|
||||
def get_subnode(node, i):
|
||||
def get_subnode(node: messages.HDNodeType, i: int) -> messages.HDNodeType:
|
||||
# Public Child key derivation (CKD) algorithm of BIP32
|
||||
i_as_bytes = struct.pack(">L", i)
|
||||
|
||||
@ -108,7 +109,7 @@ def get_subnode(node, i):
|
||||
)
|
||||
|
||||
|
||||
def serialize(node, version=0x0488B21E):
|
||||
def serialize(node: messages.HDNodeType, version: int = 0x0488B21E) -> str:
|
||||
s = b""
|
||||
s += struct.pack(">I", version)
|
||||
s += struct.pack(">B", node.depth)
|
||||
@ -123,7 +124,7 @@ def serialize(node, version=0x0488B21E):
|
||||
return tools.b58encode(s)
|
||||
|
||||
|
||||
def deserialize(xpub):
|
||||
def deserialize(xpub: str) -> messages.HDNodeType:
|
||||
data = tools.b58decode(xpub, None)
|
||||
|
||||
if tools.btc_hash(data[:-4])[:4] != data[-4:]:
|
||||
|
@ -1,8 +1,10 @@
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
DISPLAY_WIDTH = 240
|
||||
DISPLAY_HEIGHT = 240
|
||||
|
||||
|
||||
def grid(dim, grid_cells, cell):
|
||||
def grid(dim: int, grid_cells: int, cell: int) -> int:
|
||||
step = dim // grid_cells
|
||||
ofs = step // 2
|
||||
return cell * step + ofs
|
||||
@ -34,15 +36,15 @@ RESET_WORD_CHECK = [
|
||||
BUTTON_LETTERS = ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz")
|
||||
|
||||
|
||||
def grid35(x, y):
|
||||
def grid35(x: int, y: int) -> Tuple[int, int]:
|
||||
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 5, y)
|
||||
|
||||
|
||||
def grid34(x, y):
|
||||
def grid34(x: int, y: int) -> Tuple[int, int]:
|
||||
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 4, y)
|
||||
|
||||
|
||||
def type_word(word):
|
||||
def type_word(word: str) -> Iterator[Tuple[int, int]]:
|
||||
for l in word:
|
||||
idx = next(i for i, letters in enumerate(BUTTON_LETTERS) if l in letters)
|
||||
grid_x = idx % 3
|
||||
|
@ -18,12 +18,18 @@ import json
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Generator, List, Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from trezorlib import btc, tools
|
||||
from trezorlib.messages import ButtonRequestType as B
|
||||
from trezorlib.messages import ButtonRequestType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorlib.debuglink import DebugLink, TrezorClientDebugLink as Client
|
||||
from trezorlib.messages import ButtonRequest
|
||||
from _pytest.mark.structures import MarkDecorator
|
||||
|
||||
# fmt: off
|
||||
# 1 2 3 4 5 6 7 8 9 10 11 12
|
||||
@ -56,7 +62,7 @@ COMMON_FIXTURES_DIR = (
|
||||
)
|
||||
|
||||
|
||||
def parametrize_using_common_fixtures(*paths):
|
||||
def parametrize_using_common_fixtures(*paths: str) -> "MarkDecorator":
|
||||
fixtures = []
|
||||
for path in paths:
|
||||
fixtures.append(json.loads((COMMON_FIXTURES_DIR / path).read_text()))
|
||||
@ -85,7 +91,9 @@ def parametrize_using_common_fixtures(*paths):
|
||||
return pytest.mark.parametrize("parameters, result", tests)
|
||||
|
||||
|
||||
def generate_entropy(strength, internal_entropy, external_entropy):
|
||||
def generate_entropy(
|
||||
strength: int, internal_entropy: bytes, external_entropy: bytes
|
||||
) -> bytes:
|
||||
"""
|
||||
strength - length of produced seed. One of 128, 192, 256
|
||||
random - binary stream of random data from external HRNG
|
||||
@ -116,7 +124,12 @@ def generate_entropy(strength, internal_entropy, external_entropy):
|
||||
return entropy_stripped
|
||||
|
||||
|
||||
def recovery_enter_shares(debug, shares, groups=False, click_info=False):
|
||||
def recovery_enter_shares(
|
||||
debug: "DebugLink",
|
||||
shares: List[str],
|
||||
groups: bool = False,
|
||||
click_info: bool = False,
|
||||
) -> Generator[None, "ButtonRequest", None]:
|
||||
"""Perform the recovery flow for a set of Shamir shares.
|
||||
|
||||
For use in an input flow function.
|
||||
@ -134,7 +147,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
|
||||
debug.press_yes()
|
||||
# Input word number
|
||||
br = yield
|
||||
assert br.code == B.MnemonicWordCount
|
||||
assert br.code == ButtonRequestType.MnemonicWordCount
|
||||
debug.input(str(word_count))
|
||||
# Homescreen - proceed to share entry
|
||||
yield
|
||||
@ -142,7 +155,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
|
||||
# Enter shares
|
||||
for share in shares:
|
||||
br = yield
|
||||
assert br.code == B.MnemonicInput
|
||||
assert br.code == ButtonRequestType.MnemonicInput
|
||||
# Enter mnemonic words
|
||||
for word in share.split(" "):
|
||||
debug.input(word)
|
||||
@ -167,7 +180,9 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
|
||||
debug.press_yes()
|
||||
|
||||
|
||||
def click_through(debug, screens, code=None):
|
||||
def click_through(
|
||||
debug: "DebugLink", screens: int, code: ButtonRequestType = None
|
||||
) -> Generator[None, "ButtonRequest", None]:
|
||||
"""Click through N dialog screens.
|
||||
|
||||
For use in an input flow function.
|
||||
@ -178,7 +193,7 @@ def click_through(debug, screens, code=None):
|
||||
# 2. Backup your seed
|
||||
# 3. Confirm warning
|
||||
# 4. Shares info
|
||||
yield from click_through(client.debug, screens=4, code=B.ResetDevice)
|
||||
yield from click_through(client.debug, screens=4, code=ButtonRequestType.ResetDevice)
|
||||
"""
|
||||
for _ in range(screens):
|
||||
received = yield
|
||||
@ -187,7 +202,9 @@ def click_through(debug, screens, code=None):
|
||||
debug.press_yes()
|
||||
|
||||
|
||||
def read_and_confirm_mnemonic(debug, choose_wrong=False):
|
||||
def read_and_confirm_mnemonic(
|
||||
debug: "DebugLink", choose_wrong: bool = False
|
||||
) -> Generator[None, "ButtonRequest", Optional[str]]:
|
||||
"""Read a given number of mnemonic words from Trezor T screen and correctly
|
||||
answer confirmation questions. Return the full mnemonic.
|
||||
|
||||
@ -201,6 +218,7 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False):
|
||||
"""
|
||||
mnemonic = []
|
||||
br = yield
|
||||
assert br.pages is not None
|
||||
for _ in range(br.pages - 1):
|
||||
mnemonic.extend(debug.read_reset_word().split())
|
||||
debug.swipe_up(wait=True)
|
||||
@ -221,13 +239,15 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False):
|
||||
return " ".join(mnemonic)
|
||||
|
||||
|
||||
def get_test_address(client):
|
||||
def get_test_address(client: "Client") -> str:
|
||||
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
|
||||
protected call, or to identify the root secret (seed+passphrase)"""
|
||||
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
|
||||
|
||||
|
||||
def assert_tx_matches(serialized_tx: bytes, hash_link: str, tx_hex: str = None) -> None:
|
||||
def assert_tx_matches(
|
||||
serialized_tx: bytes, hash_link: str, tx_hex: str = None
|
||||
) -> None:
|
||||
"""Verifies if a transaction is correctly formed."""
|
||||
hash_str = hash_link.split("/")[-1]
|
||||
assert tools.tx_hash(serialized_tx).hex() == hash_str
|
||||
|
@ -15,6 +15,7 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
@ -27,12 +28,17 @@ from . import ui_tests
|
||||
from .device_handler import BackgroundDeviceHandler
|
||||
from .ui_tests.reporting import testreport
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _pytest.config import Config
|
||||
from _pytest.config.argparsing import Parser
|
||||
from _pytest.terminal import TerminalReporter
|
||||
|
||||
# So that we see details of failed asserts from this module
|
||||
pytest.register_assert_rewrite("tests.common")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _raw_client(request):
|
||||
def _raw_client(request: pytest.FixtureRequest) -> TrezorClientDebugLink:
|
||||
path = os.environ.get("TREZOR_PATH")
|
||||
interact = int(os.environ.get("INTERACT", 0))
|
||||
if path:
|
||||
@ -56,7 +62,9 @@ def _raw_client(request):
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(request, _raw_client):
|
||||
def client(
|
||||
request: pytest.FixtureRequest, _raw_client: TrezorClientDebugLink
|
||||
) -> TrezorClientDebugLink:
|
||||
"""Client fixture.
|
||||
|
||||
Every test function that requires a client instance will get it from here.
|
||||
@ -156,20 +164,20 @@ def client(request, _raw_client):
|
||||
_raw_client.close()
|
||||
|
||||
|
||||
def pytest_sessionstart(session):
|
||||
def pytest_sessionstart(session: pytest.Session) -> None:
|
||||
ui_tests.read_fixtures()
|
||||
if session.config.getoption("ui") == "test":
|
||||
testreport.clear_dir()
|
||||
|
||||
|
||||
def _should_write_ui_report(exitstatus):
|
||||
def _should_write_ui_report(exitstatus: pytest.ExitCode) -> bool:
|
||||
# generate UI report and check missing only if pytest is exitting cleanly
|
||||
# I.e., the test suite passed or failed (as opposed to ctrl+c break, internal error,
|
||||
# etc.)
|
||||
return exitstatus in (pytest.ExitCode.OK, pytest.ExitCode.TESTS_FAILED)
|
||||
|
||||
|
||||
def pytest_sessionfinish(session, exitstatus):
|
||||
def pytest_sessionfinish(session: pytest.Session, exitstatus: pytest.ExitCode) -> None:
|
||||
if not _should_write_ui_report(exitstatus):
|
||||
return
|
||||
|
||||
@ -183,7 +191,9 @@ def pytest_sessionfinish(session, exitstatus):
|
||||
ui_tests.write_fixtures(missing)
|
||||
|
||||
|
||||
def pytest_terminal_summary(terminalreporter, exitstatus, config):
|
||||
def pytest_terminal_summary(
|
||||
terminalreporter: "TerminalReporter", exitstatus: pytest.ExitCode, config: "Config"
|
||||
) -> None:
|
||||
println = terminalreporter.write_line
|
||||
println("")
|
||||
|
||||
@ -213,7 +223,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
|
||||
println("")
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
def pytest_addoption(parser: "Parser") -> None:
|
||||
parser.addoption(
|
||||
"--ui",
|
||||
action="store",
|
||||
@ -229,7 +239,7 @@ def pytest_addoption(parser):
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
def pytest_configure(config: "Config") -> None:
|
||||
"""Called at testsuite setup time.
|
||||
|
||||
Registers known markers, enables verbose output if requested.
|
||||
@ -253,7 +263,7 @@ def pytest_configure(config):
|
||||
log.enable_debug_output()
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
def pytest_runtest_setup(item: pytest.Item) -> None:
|
||||
"""Called for each test item (class, individual tests).
|
||||
|
||||
Ensures that altcoin tests are skipped, and that no test is skipped on
|
||||
@ -267,7 +277,7 @@ def pytest_runtest_setup(item):
|
||||
pytest.skip("Skipping altcoin test")
|
||||
|
||||
|
||||
def pytest_runtest_teardown(item):
|
||||
def pytest_runtest_teardown(item: pytest.Item) -> None:
|
||||
"""Called after a test item finishes.
|
||||
|
||||
Dumps the current UI test report HTML.
|
||||
@ -277,7 +287,7 @@ def pytest_runtest_teardown(item):
|
||||
|
||||
|
||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||
def pytest_runtest_makereport(item, call):
|
||||
def pytest_runtest_makereport(item: pytest.Item, call) -> None:
|
||||
# Make test results available in fixtures.
|
||||
# See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
|
||||
# The device_handler fixture uses this as 'request.node.rep_call.passed' attribute,
|
||||
@ -288,7 +298,9 @@ def pytest_runtest_makereport(item, call):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device_handler(client, request):
|
||||
def device_handler(
|
||||
client: TrezorClientDebugLink, request: pytest.FixtureRequest
|
||||
) -> None:
|
||||
device_handler = BackgroundDeviceHandler(client)
|
||||
yield device_handler
|
||||
|
||||
@ -299,5 +311,5 @@ def device_handler(client, request):
|
||||
|
||||
# if test finished, make sure all background tasks are done
|
||||
finalized_ok = device_handler.check_finalize()
|
||||
if request.node.rep_call.passed and not finalized_ok:
|
||||
if request.node.rep_call.passed and not finalized_ok: # type: ignore [rep_call must exist]
|
||||
raise RuntimeError("Test did not check result of background task")
|
||||
|
@ -1,8 +1,15 @@
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezorlib.client import PASSPHRASE_ON_DEVICE
|
||||
from trezorlib.transport import udp
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorlib.messages import Features
|
||||
from trezorlib.debuglink import TrezorClientDebugLink, DebugLink
|
||||
from trezorlib._internal.emulator import Emulator
|
||||
|
||||
|
||||
udp.SOCKET_TIMEOUT = 0.1
|
||||
|
||||
|
||||
@ -26,13 +33,13 @@ class NullUI:
|
||||
class BackgroundDeviceHandler:
|
||||
_pool = ThreadPoolExecutor()
|
||||
|
||||
def __init__(self, client):
|
||||
def __init__(self, client: "TrezorClientDebugLink") -> None:
|
||||
self._configure_client(client)
|
||||
self.task = None
|
||||
|
||||
def _configure_client(self, client):
|
||||
def _configure_client(self, client: "TrezorClientDebugLink") -> None:
|
||||
self.client = client
|
||||
self.client.ui = NullUI
|
||||
self.client.ui = NullUI # type: ignore [NullUI is OK UI]
|
||||
self.client.watch_layout(True)
|
||||
|
||||
def run(self, function, *args, **kwargs):
|
||||
@ -40,7 +47,7 @@ class BackgroundDeviceHandler:
|
||||
raise RuntimeError("Wait for previous task first")
|
||||
self.task = self._pool.submit(function, self.client, *args, **kwargs)
|
||||
|
||||
def kill_task(self):
|
||||
def kill_task(self) -> None:
|
||||
if self.task is not None:
|
||||
# Force close the client, which should raise an exception in a client
|
||||
# waiting on IO. Does not work over Bridge, because bridge doesn't have
|
||||
@ -53,11 +60,11 @@ class BackgroundDeviceHandler:
|
||||
pass
|
||||
self.task = None
|
||||
|
||||
def restart(self, emulator):
|
||||
def restart(self, emulator: "Emulator"):
|
||||
# TODO handle actual restart as well
|
||||
self.kill_task()
|
||||
emulator.restart()
|
||||
self._configure_client(emulator.client)
|
||||
self._configure_client(emulator.client) # type: ignore [client cannot be None]
|
||||
|
||||
def result(self):
|
||||
if self.task is None:
|
||||
@ -67,25 +74,25 @@ class BackgroundDeviceHandler:
|
||||
finally:
|
||||
self.task = None
|
||||
|
||||
def features(self):
|
||||
def features(self) -> "Features":
|
||||
if self.task is not None:
|
||||
raise RuntimeError("Cannot query features while task is running")
|
||||
self.client.init_device()
|
||||
return self.client.features
|
||||
|
||||
def debuglink(self):
|
||||
def debuglink(self) -> "DebugLink":
|
||||
return self.client.debug
|
||||
|
||||
def check_finalize(self):
|
||||
def check_finalize(self) -> bool:
|
||||
if self.task is not None:
|
||||
self.kill_task()
|
||||
return False
|
||||
return True
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> "BackgroundDeviceHandler":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
finalized_ok = self.check_finalize()
|
||||
if exc_type is None and not finalized_ok:
|
||||
raise RuntimeError("Exit while task is unfinished")
|
||||
|
@ -17,8 +17,9 @@
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from trezorlib._internal.emulator import CoreEmulator, LegacyEmulator
|
||||
from trezorlib._internal.emulator import CoreEmulator, Emulator, LegacyEmulator
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
BINDIR = ROOT / "tests" / "emulators"
|
||||
@ -33,18 +34,18 @@ CORE_SRC_DIR = ROOT / "core" / "src"
|
||||
ENV = {"SDL_VIDEODRIVER": "dummy"}
|
||||
|
||||
|
||||
def check_version(tag, version_tuple):
|
||||
def check_version(tag: str, version_tuple: Tuple[int, int, int]) -> None:
|
||||
if tag is not None and tag.startswith("v") and len(tag.split(".")) == 3:
|
||||
version = ".".join(str(i) for i in version_tuple)
|
||||
if tag[1:] != version:
|
||||
raise RuntimeError(f"Version mismatch: tag {tag} reports version {version}")
|
||||
|
||||
|
||||
def filename_from_tag(gen, tag):
|
||||
def filename_from_tag(gen: str, tag: str) -> Path:
|
||||
return BINDIR / f"trezor-emu-{gen}-{tag}"
|
||||
|
||||
|
||||
def get_tags():
|
||||
def get_tags() -> Dict[str, List[str]]:
|
||||
files = list(BINDIR.iterdir())
|
||||
if not files:
|
||||
raise ValueError(
|
||||
@ -66,7 +67,7 @@ ALL_TAGS = get_tags()
|
||||
|
||||
|
||||
class EmulatorWrapper:
|
||||
def __init__(self, gen, tag=None, storage=None):
|
||||
def __init__(self, gen: str, tag: str = None, storage: bytes = None) -> None:
|
||||
if tag is not None:
|
||||
executable = filename_from_tag(gen, tag)
|
||||
else:
|
||||
@ -96,11 +97,15 @@ class EmulatorWrapper:
|
||||
workdir=workdir,
|
||||
headless=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized gen - {gen} - only 'core' and 'legacy' supported"
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> Emulator:
|
||||
self.emulator.start()
|
||||
return self.emulator
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
self.emulator.stop()
|
||||
self.profile_dir.cleanup()
|
||||
|
@ -5,9 +5,9 @@ import multiprocessing
|
||||
import os
|
||||
import posixpath
|
||||
import time
|
||||
import urllib
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from urllib.parse import unquote
|
||||
|
||||
import click
|
||||
|
||||
@ -16,16 +16,16 @@ TEST_RESULT_PATH = ROOT / "tests" / "ui_tests" / "reporting" / "reports" / "test
|
||||
|
||||
|
||||
class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
def end_headers(self):
|
||||
def end_headers(self) -> None:
|
||||
self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||
self.send_header("Pragma", "no-cache")
|
||||
self.send_header("Expires", "0")
|
||||
return super().end_headers()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
def log_message(self, format, *args) -> None:
|
||||
pass
|
||||
|
||||
def translate_path(self, path):
|
||||
def translate_path(self, path: str) -> str:
|
||||
# XXX
|
||||
# Copy-pasted from Python 3.8 BaseHTTPRequestHandler so that we can inject
|
||||
# the `directory` parameter.
|
||||
@ -36,9 +36,9 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
# Don't forget explicit trailing slash when normalizing. Issue17324
|
||||
trailing_slash = path.rstrip().endswith("/")
|
||||
try:
|
||||
path = urllib.parse.unquote(path, errors="surrogatepass")
|
||||
path = unquote(path, errors="surrogatepass")
|
||||
except UnicodeDecodeError:
|
||||
path = urllib.parse.unquote(path)
|
||||
path = unquote(path)
|
||||
path = posixpath.normpath(path)
|
||||
words = path.split("/")
|
||||
words = filter(None, words)
|
||||
@ -53,13 +53,13 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
return path
|
||||
|
||||
|
||||
def launch_http_server(port):
|
||||
http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port)
|
||||
def launch_http_server(port: int) -> None:
|
||||
http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port) # type: ignore [test is defined]
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("-p", "--port", type=int, default=8000)
|
||||
def main(port):
|
||||
def main(port: int):
|
||||
httpd = multiprocessing.Process(target=launch_http_server, args=(port,))
|
||||
httpd.start()
|
||||
time.sleep(0.5)
|
||||
|
@ -122,7 +122,7 @@ def cli(tx, coin_name):
|
||||
tx_dict = protobuf.to_dict(tx_proto)
|
||||
tx_json = json.dumps(tx_dict, sort_keys=True, indent=2) + "\n"
|
||||
except Exception as e:
|
||||
raise click.ClickException(e) from e
|
||||
raise click.ClickException(str(e)) from e
|
||||
|
||||
cache_dir = CACHE_PATH / coin_name
|
||||
if not cache_dir.exists():
|
||||
|
@ -4,23 +4,26 @@ import re
|
||||
import shutil
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator, Set
|
||||
|
||||
import pytest
|
||||
from _pytest.outcomes import Failed
|
||||
from PIL import Image
|
||||
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
|
||||
from .reporting import testreport
|
||||
|
||||
UI_TESTS_DIR = Path(__file__).resolve().parent
|
||||
SCREENS_DIR = UI_TESTS_DIR / "screens"
|
||||
HASH_FILE = UI_TESTS_DIR / "fixtures.json"
|
||||
SUGGESTION_FILE = UI_TESTS_DIR / "fixtures.suggestion.json"
|
||||
FILE_HASHES = {}
|
||||
ACTUAL_HASHES = {}
|
||||
PROCESSED = set()
|
||||
FILE_HASHES: Dict[str, str] = {}
|
||||
ACTUAL_HASHES: Dict[str, str] = {}
|
||||
PROCESSED: Set[str] = set()
|
||||
|
||||
|
||||
def get_test_name(node_id):
|
||||
def get_test_name(node_id: str) -> str:
|
||||
# Test item name is usually function name, but when parametrization is used,
|
||||
# parameters are also part of the name. Some functions have very long parameter
|
||||
# names (tx hashes etc) that run out of maximum allowable filename length, so
|
||||
@ -34,14 +37,14 @@ def get_test_name(node_id):
|
||||
return new_name[:91] + "-" + hashlib.sha256(new_name.encode()).hexdigest()[:8]
|
||||
|
||||
|
||||
def _process_recorded(screen_path, test_name):
|
||||
def _process_recorded(screen_path: Path, test_name: str) -> None:
|
||||
# calculate hash
|
||||
FILE_HASHES[test_name] = _hash_files(screen_path)
|
||||
_rename_records(screen_path)
|
||||
PROCESSED.add(test_name)
|
||||
|
||||
|
||||
def _rename_records(screen_path):
|
||||
def _rename_records(screen_path: Path) -> None:
|
||||
# rename screenshots
|
||||
for index, record in enumerate(sorted(screen_path.iterdir())):
|
||||
record.replace(screen_path / f"{index:08}.png")
|
||||
@ -65,7 +68,7 @@ def _get_bytes_from_png(png_file: str) -> bytes:
|
||||
return Image.open(png_file).tobytes()
|
||||
|
||||
|
||||
def _process_tested(fixture_test_path, test_name):
|
||||
def _process_tested(fixture_test_path: Path, test_name: str) -> None:
|
||||
PROCESSED.add(test_name)
|
||||
|
||||
actual_path = fixture_test_path / "actual"
|
||||
@ -79,6 +82,7 @@ def _process_tested(fixture_test_path, test_name):
|
||||
pytest.fail(f"Hash of {test_name} not found in fixtures.json")
|
||||
|
||||
if actual_hash != expected_hash:
|
||||
assert expected_hash is not None
|
||||
file_path = testreport.failed(
|
||||
fixture_test_path, test_name, actual_hash, expected_hash
|
||||
)
|
||||
@ -94,7 +98,9 @@ def _process_tested(fixture_test_path, test_name):
|
||||
|
||||
|
||||
@contextmanager
|
||||
def screen_recording(client, request):
|
||||
def screen_recording(
|
||||
client: Client, request: pytest.FixtureRequest
|
||||
) -> Generator[None, None, None]:
|
||||
test_ui = request.config.getoption("ui")
|
||||
test_name = get_test_name(request.node.nodeid)
|
||||
screens_test_path = SCREENS_DIR / test_name
|
||||
@ -126,26 +132,26 @@ def screen_recording(client, request):
|
||||
_process_tested(screens_test_path, test_name)
|
||||
|
||||
|
||||
def list_missing():
|
||||
def list_missing() -> Set[str]:
|
||||
return set(FILE_HASHES.keys()) - PROCESSED
|
||||
|
||||
|
||||
def read_fixtures():
|
||||
def read_fixtures() -> None:
|
||||
if not HASH_FILE.exists():
|
||||
raise ValueError("File fixtures.json not found.")
|
||||
global FILE_HASHES
|
||||
FILE_HASHES = json.loads(HASH_FILE.read_text())
|
||||
|
||||
|
||||
def write_fixtures(remove_missing: bool):
|
||||
def write_fixtures(remove_missing: bool) -> None:
|
||||
HASH_FILE.write_text(_get_fixtures_content(FILE_HASHES, remove_missing))
|
||||
|
||||
|
||||
def write_fixtures_suggestion(remove_missing: bool):
|
||||
def write_fixtures_suggestion(remove_missing: bool) -> None:
|
||||
SUGGESTION_FILE.write_text(_get_fixtures_content(ACTUAL_HASHES, remove_missing))
|
||||
|
||||
|
||||
def _get_fixtures_content(fixtures: dict, remove_missing: bool):
|
||||
def _get_fixtures_content(fixtures: Dict[str, str], remove_missing: bool) -> str:
|
||||
if remove_missing:
|
||||
fixtures = {i: fixtures[i] for i in PROCESSED}
|
||||
else:
|
||||
@ -154,7 +160,7 @@ def _get_fixtures_content(fixtures: dict, remove_missing: bool):
|
||||
return json.dumps(fixtures, indent="", sort_keys=True) + "\n"
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
read_fixtures()
|
||||
for record in SCREENS_DIR.iterdir():
|
||||
if not (record / "actual").exists():
|
||||
|
@ -14,7 +14,7 @@ FIXTURES_CURRENT = Path(__file__).resolve().parent.parent / "fixtures.json"
|
||||
_dns_failed = False
|
||||
|
||||
|
||||
def fetch_recorded(hash, path):
|
||||
def fetch_recorded(hash: str, path: Path) -> None:
|
||||
global _dns_failed
|
||||
|
||||
if _dns_failed:
|
||||
|
@ -1,11 +1,15 @@
|
||||
import base64
|
||||
import filecmp
|
||||
from itertools import zip_longest
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from dominate.tags import a, i, img, table, td, th, tr
|
||||
|
||||
|
||||
def report_links(tests, reports_path, actual_hashes=None):
|
||||
def report_links(
|
||||
tests: List[Path], reports_path: Path, actual_hashes: Dict[str, str] = None
|
||||
) -> None:
|
||||
if actual_hashes is None:
|
||||
actual_hashes = {}
|
||||
|
||||
@ -21,12 +25,12 @@ def report_links(tests, reports_path, actual_hashes=None):
|
||||
td(a(test.name, href=path))
|
||||
|
||||
|
||||
def write(fixture_test_path, doc, filename):
|
||||
def write(fixture_test_path: Path, doc, filename: str) -> Path:
|
||||
(fixture_test_path / filename).write_text(doc.render())
|
||||
return fixture_test_path / filename
|
||||
|
||||
|
||||
def image(src):
|
||||
def image(src: Path) -> None:
|
||||
with td():
|
||||
if src:
|
||||
# open image file
|
||||
@ -41,7 +45,7 @@ def image(src):
|
||||
i("missing")
|
||||
|
||||
|
||||
def diff_table(left_screens, right_screens):
|
||||
def diff_table(left_screens: List[Path], right_screens: List[Path]) -> None:
|
||||
for left, right in zip_longest(left_screens, right_screens):
|
||||
if left and right and filecmp.cmp(right, left):
|
||||
background = "white"
|
||||
|
@ -2,6 +2,7 @@ import shutil
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import dominate
|
||||
from dominate.tags import br, h1, h2, hr, i, p, table, td, th, tr
|
||||
@ -14,7 +15,7 @@ REPORTS_PATH = Path(__file__).resolve().parent / "reports" / "master_diff"
|
||||
RECORDED_SCREENS_PATH = Path(__file__).resolve().parent.parent / "screens"
|
||||
|
||||
|
||||
def get_diff():
|
||||
def get_diff() -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
|
||||
master = download.fetch_fixtures_master()
|
||||
current = download.fetch_fixtures_current()
|
||||
|
||||
@ -34,7 +35,7 @@ def get_diff():
|
||||
return removed, added, diff
|
||||
|
||||
|
||||
def removed(screens_path, test_name):
|
||||
def removed(screens_path: Path, test_name: str) -> Path:
|
||||
doc = dominate.document(title=test_name)
|
||||
screens = sorted(screens_path.iterdir())
|
||||
|
||||
@ -57,7 +58,7 @@ def removed(screens_path, test_name):
|
||||
return html.write(REPORTS_PATH / "removed", doc, test_name + ".html")
|
||||
|
||||
|
||||
def added(screens_path, test_name):
|
||||
def added(screens_path: Path, test_name: str) -> Path:
|
||||
doc = dominate.document(title=test_name)
|
||||
screens = sorted(screens_path.iterdir())
|
||||
|
||||
@ -81,8 +82,12 @@ def added(screens_path, test_name):
|
||||
|
||||
|
||||
def diff(
|
||||
master_screens_path, current_screens_path, test_name, master_hash, current_hash
|
||||
):
|
||||
master_screens_path: Path,
|
||||
current_screens_path: Path,
|
||||
test_name: str,
|
||||
master_hash: str,
|
||||
current_hash: str,
|
||||
) -> Path:
|
||||
doc = dominate.document(title=test_name)
|
||||
master_screens = sorted(master_screens_path.iterdir())
|
||||
current_screens = sorted(current_screens_path.iterdir())
|
||||
@ -109,7 +114,7 @@ def diff(
|
||||
return html.write(REPORTS_PATH / "diff", doc, test_name + ".html")
|
||||
|
||||
|
||||
def index():
|
||||
def index() -> Path:
|
||||
removed = list((REPORTS_PATH / "removed").iterdir())
|
||||
added = list((REPORTS_PATH / "added").iterdir())
|
||||
diff = list((REPORTS_PATH / "diff").iterdir())
|
||||
@ -140,7 +145,7 @@ def index():
|
||||
return html.write(REPORTS_PATH, doc, "index.html")
|
||||
|
||||
|
||||
def create_dirs():
|
||||
def create_dirs() -> None:
|
||||
# delete the reports dir to clear previous entries and create folders
|
||||
shutil.rmtree(REPORTS_PATH, ignore_errors=True)
|
||||
REPORTS_PATH.mkdir()
|
||||
@ -149,7 +154,7 @@ def create_dirs():
|
||||
(REPORTS_PATH / "diff").mkdir()
|
||||
|
||||
|
||||
def create_reports():
|
||||
def create_reports() -> None:
|
||||
removed_tests, added_tests, diff_tests = get_diff()
|
||||
|
||||
@contextmanager
|
||||
|
@ -2,6 +2,7 @@ import shutil
|
||||
from datetime import datetime
|
||||
from distutils.dir_util import copy_tree
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import dominate
|
||||
import dominate.tags as t
|
||||
@ -16,10 +17,12 @@ REPORTS_PATH = HERE / "reports" / "test"
|
||||
STYLE = (HERE / "testreport.css").read_text()
|
||||
SCRIPT = (HERE / "testreport.js").read_text()
|
||||
|
||||
ACTUAL_HASHES = {}
|
||||
ACTUAL_HASHES: Dict[str, str] = {}
|
||||
|
||||
|
||||
def document(title, actual_hash=None, index=False):
|
||||
def document(
|
||||
title: str, actual_hash: str = None, index: bool = False
|
||||
) -> dominate.document:
|
||||
doc = dominate.document(title=title)
|
||||
style = t.style()
|
||||
style.add_raw_string(STYLE)
|
||||
@ -36,7 +39,7 @@ def document(title, actual_hash=None, index=False):
|
||||
return doc
|
||||
|
||||
|
||||
def _header(test_name, expected_hash, actual_hash):
|
||||
def _header(test_name: str, expected_hash: str, actual_hash: str) -> None:
|
||||
h1(test_name)
|
||||
with div():
|
||||
if actual_hash == expected_hash:
|
||||
@ -54,7 +57,7 @@ def _header(test_name, expected_hash, actual_hash):
|
||||
hr()
|
||||
|
||||
|
||||
def clear_dir():
|
||||
def clear_dir() -> None:
|
||||
# delete and create the reports dir to clear previous entries
|
||||
shutil.rmtree(REPORTS_PATH, ignore_errors=True)
|
||||
REPORTS_PATH.mkdir()
|
||||
@ -62,7 +65,7 @@ def clear_dir():
|
||||
(REPORTS_PATH / "passed").mkdir()
|
||||
|
||||
|
||||
def index():
|
||||
def index() -> Path:
|
||||
passed_tests = list((REPORTS_PATH / "passed").iterdir())
|
||||
failed_tests = list((REPORTS_PATH / "failed").iterdir())
|
||||
|
||||
@ -104,7 +107,9 @@ def index():
|
||||
return html.write(REPORTS_PATH, doc, "index.html")
|
||||
|
||||
|
||||
def failed(fixture_test_path, test_name, actual_hash, expected_hash):
|
||||
def failed(
|
||||
fixture_test_path: Path, test_name: str, actual_hash: str, expected_hash: str
|
||||
) -> Path:
|
||||
ACTUAL_HASHES[test_name] = actual_hash
|
||||
|
||||
doc = document(title=test_name, actual_hash=actual_hash)
|
||||
@ -147,7 +152,7 @@ def failed(fixture_test_path, test_name, actual_hash, expected_hash):
|
||||
return html.write(REPORTS_PATH / "failed", doc, test_name + ".html")
|
||||
|
||||
|
||||
def passed(fixture_test_path, test_name, actual_hash):
|
||||
def passed(fixture_test_path: Path, test_name: str, actual_hash: str) -> Path:
|
||||
copy_tree(str(fixture_test_path / "actual"), str(fixture_test_path / "recorded"))
|
||||
|
||||
doc = document(title=test_name)
|
||||
|
@ -15,8 +15,10 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
from _pytest.mark.structures import MarkDecorator
|
||||
|
||||
from ..emulators import ALL_TAGS, LOCAL_BUILD_PATHS
|
||||
|
||||
@ -44,7 +46,11 @@ core_only = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
|
||||
def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0, 0)):
|
||||
def for_all(
|
||||
*args,
|
||||
legacy_minimum_version: Tuple[int, int, int] = (1, 0, 0),
|
||||
core_minimum_version: Tuple[int, int, int] = (2, 0, 0)
|
||||
) -> "MarkDecorator":
|
||||
"""Parametrizing decorator for test cases.
|
||||
|
||||
Usage example:
|
||||
@ -98,7 +104,7 @@ def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0,
|
||||
return pytest.mark.parametrize("gen, tag", all_params)
|
||||
|
||||
|
||||
def for_tags(*args):
|
||||
def for_tags(*args: Tuple[str, List[str]]) -> "MarkDecorator":
|
||||
enabled_gens = SELECTED_GENS or ("core", "legacy")
|
||||
return pytest.mark.parametrize(
|
||||
"gen, tags", [(gen, tags) for gen, tags in args if gen in enabled_gens]
|
||||
|
Loading…
Reference in New Issue
Block a user