diff --git a/tests/bip32.py b/tests/bip32.py index cfa833c6a..6fa08d6c2 100644 --- a/tests/bip32.py +++ b/tests/bip32.py @@ -18,6 +18,7 @@ import hashlib import hmac import struct from copy import copy +from typing import Any, List, Tuple import ecdsa from ecdsa.curves import SECP256k1 @@ -27,7 +28,7 @@ from ecdsa.util import number_to_string, string_to_number from trezorlib import messages, tools -def point_to_pubkey(point): +def point_to_pubkey(point: Point) -> bytes: order = SECP256k1.order x_str = number_to_string(point.x(), order) y_str = number_to_string(point.y(), order) @@ -35,14 +36,14 @@ def point_to_pubkey(point): return struct.pack("B", (vk[63] & 1) + 2) + vk[0:32] # To compressed key -def sec_to_public_pair(pubkey): +def sec_to_public_pair(pubkey: bytes) -> Tuple[int, Any]: """Convert a public key in sec binary format to a public pair.""" x = string_to_number(pubkey[1:33]) sec0 = pubkey[:1] if sec0 not in (b"\2", b"\3"): raise ValueError("Compressed pubkey expected") - def public_pair_for_x(generator, x, is_even): + def public_pair_for_x(generator, x: int, is_even: bool) -> Tuple[int, Any]: curve = generator.curve() p = curve.p() alpha = (pow(x, 3, p) + curve.a() * x + curve.b()) % p @@ -56,15 +57,15 @@ def sec_to_public_pair(pubkey): ) -def fingerprint(pubkey): +def fingerprint(pubkey: bytes) -> int: return string_to_number(tools.hash_160(pubkey)[:4]) -def get_address(public_node, address_type): +def get_address(public_node: messages.HDNodeType, address_type: int) -> str: return tools.public_key_to_bc_address(public_node.public_key, address_type) -def public_ckd(public_node, n): +def public_ckd(public_node: messages.HDNodeType, n: List[int]): if not isinstance(n, list): raise ValueError("Parameter must be a list") @@ -76,7 +77,7 @@ def public_ckd(public_node, n): return node -def get_subnode(node, i): +def get_subnode(node: messages.HDNodeType, i: int) -> messages.HDNodeType: # Public Child key derivation (CKD) algorithm of BIP32 i_as_bytes = struct.pack(">L", i) @@ -108,7 +109,7 @@ def get_subnode(node, i): ) -def serialize(node, version=0x0488B21E): +def serialize(node: messages.HDNodeType, version: int = 0x0488B21E) -> str: s = b"" s += struct.pack(">I", version) s += struct.pack(">B", node.depth) @@ -123,7 +124,7 @@ def serialize(node, version=0x0488B21E): return tools.b58encode(s) -def deserialize(xpub): +def deserialize(xpub: str) -> messages.HDNodeType: data = tools.b58decode(xpub, None) if tools.btc_hash(data[:-4])[:4] != data[-4:]: diff --git a/tests/buttons.py b/tests/buttons.py index e65371e16..684557901 100644 --- a/tests/buttons.py +++ b/tests/buttons.py @@ -1,8 +1,10 @@ +from typing import Iterator, Tuple + DISPLAY_WIDTH = 240 DISPLAY_HEIGHT = 240 -def grid(dim, grid_cells, cell): +def grid(dim: int, grid_cells: int, cell: int) -> int: step = dim // grid_cells ofs = step // 2 return cell * step + ofs @@ -34,15 +36,15 @@ RESET_WORD_CHECK = [ BUTTON_LETTERS = ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz") -def grid35(x, y): +def grid35(x: int, y: int) -> Tuple[int, int]: return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 5, y) -def grid34(x, y): +def grid34(x: int, y: int) -> Tuple[int, int]: return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 4, y) -def type_word(word): +def type_word(word: str) -> Iterator[Tuple[int, int]]: for l in word: idx = next(i for i, letters in enumerate(BUTTON_LETTERS) if l in letters) grid_x = idx % 3 diff --git a/tests/common.py b/tests/common.py index a300b7171..c8b252912 100644 --- a/tests/common.py +++ b/tests/common.py @@ -18,12 +18,18 @@ import json import os from decimal import Decimal from pathlib import Path +from typing import TYPE_CHECKING, Generator, List, Optional import pytest import requests from trezorlib import btc, tools -from trezorlib.messages import ButtonRequestType as B +from trezorlib.messages import ButtonRequestType + +if TYPE_CHECKING: + from trezorlib.debuglink import DebugLink, TrezorClientDebugLink as Client + from trezorlib.messages import ButtonRequest + from _pytest.mark.structures import MarkDecorator # fmt: off # 1 2 3 4 5 6 7 8 9 10 11 12 @@ -56,7 +62,7 @@ COMMON_FIXTURES_DIR = ( ) -def parametrize_using_common_fixtures(*paths): +def parametrize_using_common_fixtures(*paths: str) -> "MarkDecorator": fixtures = [] for path in paths: fixtures.append(json.loads((COMMON_FIXTURES_DIR / path).read_text())) @@ -85,7 +91,9 @@ def parametrize_using_common_fixtures(*paths): return pytest.mark.parametrize("parameters, result", tests) -def generate_entropy(strength, internal_entropy, external_entropy): +def generate_entropy( + strength: int, internal_entropy: bytes, external_entropy: bytes +) -> bytes: """ strength - length of produced seed. One of 128, 192, 256 random - binary stream of random data from external HRNG @@ -116,7 +124,12 @@ def generate_entropy(strength, internal_entropy, external_entropy): return entropy_stripped -def recovery_enter_shares(debug, shares, groups=False, click_info=False): +def recovery_enter_shares( + debug: "DebugLink", + shares: List[str], + groups: bool = False, + click_info: bool = False, +) -> Generator[None, "ButtonRequest", None]: """Perform the recovery flow for a set of Shamir shares. For use in an input flow function. @@ -134,7 +147,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False): debug.press_yes() # Input word number br = yield - assert br.code == B.MnemonicWordCount + assert br.code == ButtonRequestType.MnemonicWordCount debug.input(str(word_count)) # Homescreen - proceed to share entry yield @@ -142,7 +155,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False): # Enter shares for share in shares: br = yield - assert br.code == B.MnemonicInput + assert br.code == ButtonRequestType.MnemonicInput # Enter mnemonic words for word in share.split(" "): debug.input(word) @@ -167,7 +180,9 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False): debug.press_yes() -def click_through(debug, screens, code=None): +def click_through( + debug: "DebugLink", screens: int, code: ButtonRequestType = None +) -> Generator[None, "ButtonRequest", None]: """Click through N dialog screens. For use in an input flow function. @@ -178,7 +193,7 @@ def click_through(debug, screens, code=None): # 2. Backup your seed # 3. Confirm warning # 4. Shares info - yield from click_through(client.debug, screens=4, code=B.ResetDevice) + yield from click_through(client.debug, screens=4, code=ButtonRequestType.ResetDevice) """ for _ in range(screens): received = yield @@ -187,7 +202,9 @@ def click_through(debug, screens, code=None): debug.press_yes() -def read_and_confirm_mnemonic(debug, choose_wrong=False): +def read_and_confirm_mnemonic( + debug: "DebugLink", choose_wrong: bool = False +) -> Generator[None, "ButtonRequest", Optional[str]]: """Read a given number of mnemonic words from Trezor T screen and correctly answer confirmation questions. Return the full mnemonic. @@ -201,6 +218,7 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False): """ mnemonic = [] br = yield + assert br.pages is not None for _ in range(br.pages - 1): mnemonic.extend(debug.read_reset_word().split()) debug.swipe_up(wait=True) @@ -221,13 +239,15 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False): return " ".join(mnemonic) -def get_test_address(client): +def get_test_address(client: "Client") -> str: """Fetch a testnet address on a fixed path. Useful to make a pin/passphrase protected call, or to identify the root secret (seed+passphrase)""" return btc.get_address(client, "Testnet", TEST_ADDRESS_N) -def assert_tx_matches(serialized_tx: bytes, hash_link: str, tx_hex: str = None) -> None: +def assert_tx_matches( + serialized_tx: bytes, hash_link: str, tx_hex: str = None +) -> None: """Verifies if a transaction is correctly formed.""" hash_str = hash_link.split("/")[-1] assert tools.tx_hash(serialized_tx).hex() == hash_str diff --git a/tests/conftest.py b/tests/conftest.py index 48e319036..bfd381a8e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ # If not, see . import os +from typing import TYPE_CHECKING import pytest @@ -27,12 +28,17 @@ from . import ui_tests from .device_handler import BackgroundDeviceHandler from .ui_tests.reporting import testreport +if TYPE_CHECKING: + from _pytest.config import Config + from _pytest.config.argparsing import Parser + from _pytest.terminal import TerminalReporter + # So that we see details of failed asserts from this module pytest.register_assert_rewrite("tests.common") @pytest.fixture(scope="session") -def _raw_client(request): +def _raw_client(request: pytest.FixtureRequest) -> TrezorClientDebugLink: path = os.environ.get("TREZOR_PATH") interact = int(os.environ.get("INTERACT", 0)) if path: @@ -56,7 +62,9 @@ def _raw_client(request): @pytest.fixture(scope="function") -def client(request, _raw_client): +def client( + request: pytest.FixtureRequest, _raw_client: TrezorClientDebugLink +) -> TrezorClientDebugLink: """Client fixture. Every test function that requires a client instance will get it from here. @@ -156,20 +164,20 @@ def client(request, _raw_client): _raw_client.close() -def pytest_sessionstart(session): +def pytest_sessionstart(session: pytest.Session) -> None: ui_tests.read_fixtures() if session.config.getoption("ui") == "test": testreport.clear_dir() -def _should_write_ui_report(exitstatus): +def _should_write_ui_report(exitstatus: pytest.ExitCode) -> bool: # generate UI report and check missing only if pytest is exitting cleanly # I.e., the test suite passed or failed (as opposed to ctrl+c break, internal error, # etc.) return exitstatus in (pytest.ExitCode.OK, pytest.ExitCode.TESTS_FAILED) -def pytest_sessionfinish(session, exitstatus): +def pytest_sessionfinish(session: pytest.Session, exitstatus: pytest.ExitCode) -> None: if not _should_write_ui_report(exitstatus): return @@ -183,7 +191,9 @@ def pytest_sessionfinish(session, exitstatus): ui_tests.write_fixtures(missing) -def pytest_terminal_summary(terminalreporter, exitstatus, config): +def pytest_terminal_summary( + terminalreporter: "TerminalReporter", exitstatus: pytest.ExitCode, config: "Config" +) -> None: println = terminalreporter.write_line println("") @@ -213,7 +223,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): println("") -def pytest_addoption(parser): +def pytest_addoption(parser: "Parser") -> None: parser.addoption( "--ui", action="store", @@ -229,7 +239,7 @@ def pytest_addoption(parser): ) -def pytest_configure(config): +def pytest_configure(config: "Config") -> None: """Called at testsuite setup time. Registers known markers, enables verbose output if requested. @@ -253,7 +263,7 @@ def pytest_configure(config): log.enable_debug_output() -def pytest_runtest_setup(item): +def pytest_runtest_setup(item: pytest.Item) -> None: """Called for each test item (class, individual tests). Ensures that altcoin tests are skipped, and that no test is skipped on @@ -267,7 +277,7 @@ def pytest_runtest_setup(item): pytest.skip("Skipping altcoin test") -def pytest_runtest_teardown(item): +def pytest_runtest_teardown(item: pytest.Item) -> None: """Called after a test item finishes. Dumps the current UI test report HTML. @@ -277,7 +287,7 @@ def pytest_runtest_teardown(item): @pytest.hookimpl(tryfirst=True, hookwrapper=True) -def pytest_runtest_makereport(item, call): +def pytest_runtest_makereport(item: pytest.Item, call) -> None: # Make test results available in fixtures. # See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures # The device_handler fixture uses this as 'request.node.rep_call.passed' attribute, @@ -288,7 +298,9 @@ def pytest_runtest_makereport(item, call): @pytest.fixture -def device_handler(client, request): +def device_handler( + client: TrezorClientDebugLink, request: pytest.FixtureRequest +) -> None: device_handler = BackgroundDeviceHandler(client) yield device_handler @@ -299,5 +311,5 @@ def device_handler(client, request): # if test finished, make sure all background tasks are done finalized_ok = device_handler.check_finalize() - if request.node.rep_call.passed and not finalized_ok: + if request.node.rep_call.passed and not finalized_ok: # type: ignore [rep_call must exist] raise RuntimeError("Test did not check result of background task") diff --git a/tests/device_handler.py b/tests/device_handler.py index a703c6813..3ad14ed92 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -1,8 +1,15 @@ from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING from trezorlib.client import PASSPHRASE_ON_DEVICE from trezorlib.transport import udp +if TYPE_CHECKING: + from trezorlib.messages import Features + from trezorlib.debuglink import TrezorClientDebugLink, DebugLink + from trezorlib._internal.emulator import Emulator + + udp.SOCKET_TIMEOUT = 0.1 @@ -26,13 +33,13 @@ class NullUI: class BackgroundDeviceHandler: _pool = ThreadPoolExecutor() - def __init__(self, client): + def __init__(self, client: "TrezorClientDebugLink") -> None: self._configure_client(client) self.task = None - def _configure_client(self, client): + def _configure_client(self, client: "TrezorClientDebugLink") -> None: self.client = client - self.client.ui = NullUI + self.client.ui = NullUI # type: ignore [NullUI is OK UI] self.client.watch_layout(True) def run(self, function, *args, **kwargs): @@ -40,7 +47,7 @@ class BackgroundDeviceHandler: raise RuntimeError("Wait for previous task first") self.task = self._pool.submit(function, self.client, *args, **kwargs) - def kill_task(self): + def kill_task(self) -> None: if self.task is not None: # Force close the client, which should raise an exception in a client # waiting on IO. Does not work over Bridge, because bridge doesn't have @@ -53,11 +60,11 @@ class BackgroundDeviceHandler: pass self.task = None - def restart(self, emulator): + def restart(self, emulator: "Emulator"): # TODO handle actual restart as well self.kill_task() emulator.restart() - self._configure_client(emulator.client) + self._configure_client(emulator.client) # type: ignore [client cannot be None] def result(self): if self.task is None: @@ -67,25 +74,25 @@ class BackgroundDeviceHandler: finally: self.task = None - def features(self): + def features(self) -> "Features": if self.task is not None: raise RuntimeError("Cannot query features while task is running") self.client.init_device() return self.client.features - def debuglink(self): + def debuglink(self) -> "DebugLink": return self.client.debug - def check_finalize(self): + def check_finalize(self) -> bool: if self.task is not None: self.kill_task() return False return True - def __enter__(self): + def __enter__(self) -> "BackgroundDeviceHandler": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: finalized_ok = self.check_finalize() if exc_type is None and not finalized_ok: raise RuntimeError("Exit while task is unfinished") diff --git a/tests/emulators.py b/tests/emulators.py index bdadaa08c..85e899da2 100644 --- a/tests/emulators.py +++ b/tests/emulators.py @@ -17,8 +17,9 @@ import tempfile from collections import defaultdict from pathlib import Path +from typing import Dict, List, Tuple -from trezorlib._internal.emulator import CoreEmulator, LegacyEmulator +from trezorlib._internal.emulator import CoreEmulator, Emulator, LegacyEmulator ROOT = Path(__file__).resolve().parent.parent BINDIR = ROOT / "tests" / "emulators" @@ -33,18 +34,18 @@ CORE_SRC_DIR = ROOT / "core" / "src" ENV = {"SDL_VIDEODRIVER": "dummy"} -def check_version(tag, version_tuple): +def check_version(tag: str, version_tuple: Tuple[int, int, int]) -> None: if tag is not None and tag.startswith("v") and len(tag.split(".")) == 3: version = ".".join(str(i) for i in version_tuple) if tag[1:] != version: raise RuntimeError(f"Version mismatch: tag {tag} reports version {version}") -def filename_from_tag(gen, tag): +def filename_from_tag(gen: str, tag: str) -> Path: return BINDIR / f"trezor-emu-{gen}-{tag}" -def get_tags(): +def get_tags() -> Dict[str, List[str]]: files = list(BINDIR.iterdir()) if not files: raise ValueError( @@ -66,7 +67,7 @@ ALL_TAGS = get_tags() class EmulatorWrapper: - def __init__(self, gen, tag=None, storage=None): + def __init__(self, gen: str, tag: str = None, storage: bytes = None) -> None: if tag is not None: executable = filename_from_tag(gen, tag) else: @@ -96,11 +97,15 @@ class EmulatorWrapper: workdir=workdir, headless=True, ) + else: + raise ValueError( + f"Unrecognized gen - {gen} - only 'core' and 'legacy' supported" + ) - def __enter__(self): + def __enter__(self) -> Emulator: self.emulator.start() return self.emulator - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: self.emulator.stop() self.profile_dir.cleanup() diff --git a/tests/show_results.py b/tests/show_results.py index 832ff298f..ed53777c6 100755 --- a/tests/show_results.py +++ b/tests/show_results.py @@ -5,9 +5,9 @@ import multiprocessing import os import posixpath import time -import urllib import webbrowser from pathlib import Path +from urllib.parse import unquote import click @@ -16,16 +16,16 @@ TEST_RESULT_PATH = ROOT / "tests" / "ui_tests" / "reporting" / "reports" / "test class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler): - def end_headers(self): + def end_headers(self) -> None: self.send_header("Cache-Control", "no-cache, no-store, must-revalidate") self.send_header("Pragma", "no-cache") self.send_header("Expires", "0") return super().end_headers() - def log_message(self, format, *args): + def log_message(self, format, *args) -> None: pass - def translate_path(self, path): + def translate_path(self, path: str) -> str: # XXX # Copy-pasted from Python 3.8 BaseHTTPRequestHandler so that we can inject # the `directory` parameter. @@ -36,9 +36,9 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler): # Don't forget explicit trailing slash when normalizing. Issue17324 trailing_slash = path.rstrip().endswith("/") try: - path = urllib.parse.unquote(path, errors="surrogatepass") + path = unquote(path, errors="surrogatepass") except UnicodeDecodeError: - path = urllib.parse.unquote(path) + path = unquote(path) path = posixpath.normpath(path) words = path.split("/") words = filter(None, words) @@ -53,13 +53,13 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler): return path -def launch_http_server(port): - http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port) +def launch_http_server(port: int) -> None: + http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port) # type: ignore [test is defined] @click.command() @click.option("-p", "--port", type=int, default=8000) -def main(port): +def main(port: int): httpd = multiprocessing.Process(target=launch_http_server, args=(port,)) httpd.start() time.sleep(0.5) diff --git a/tests/tx_cache.py b/tests/tx_cache.py index b7517ea31..5c2ae227e 100755 --- a/tests/tx_cache.py +++ b/tests/tx_cache.py @@ -122,7 +122,7 @@ def cli(tx, coin_name): tx_dict = protobuf.to_dict(tx_proto) tx_json = json.dumps(tx_dict, sort_keys=True, indent=2) + "\n" except Exception as e: - raise click.ClickException(e) from e + raise click.ClickException(str(e)) from e cache_dir = CACHE_PATH / coin_name if not cache_dir.exists(): diff --git a/tests/ui_tests/__init__.py b/tests/ui_tests/__init__.py index 03aa9562d..20a758212 100644 --- a/tests/ui_tests/__init__.py +++ b/tests/ui_tests/__init__.py @@ -4,23 +4,26 @@ import re import shutil from contextlib import contextmanager from pathlib import Path +from typing import Dict, Generator, Set import pytest from _pytest.outcomes import Failed from PIL import Image +from trezorlib.debuglink import TrezorClientDebugLink as Client + from .reporting import testreport UI_TESTS_DIR = Path(__file__).resolve().parent SCREENS_DIR = UI_TESTS_DIR / "screens" HASH_FILE = UI_TESTS_DIR / "fixtures.json" SUGGESTION_FILE = UI_TESTS_DIR / "fixtures.suggestion.json" -FILE_HASHES = {} -ACTUAL_HASHES = {} -PROCESSED = set() +FILE_HASHES: Dict[str, str] = {} +ACTUAL_HASHES: Dict[str, str] = {} +PROCESSED: Set[str] = set() -def get_test_name(node_id): +def get_test_name(node_id: str) -> str: # Test item name is usually function name, but when parametrization is used, # parameters are also part of the name. Some functions have very long parameter # names (tx hashes etc) that run out of maximum allowable filename length, so @@ -34,14 +37,14 @@ def get_test_name(node_id): return new_name[:91] + "-" + hashlib.sha256(new_name.encode()).hexdigest()[:8] -def _process_recorded(screen_path, test_name): +def _process_recorded(screen_path: Path, test_name: str) -> None: # calculate hash FILE_HASHES[test_name] = _hash_files(screen_path) _rename_records(screen_path) PROCESSED.add(test_name) -def _rename_records(screen_path): +def _rename_records(screen_path: Path) -> None: # rename screenshots for index, record in enumerate(sorted(screen_path.iterdir())): record.replace(screen_path / f"{index:08}.png") @@ -65,7 +68,7 @@ def _get_bytes_from_png(png_file: str) -> bytes: return Image.open(png_file).tobytes() -def _process_tested(fixture_test_path, test_name): +def _process_tested(fixture_test_path: Path, test_name: str) -> None: PROCESSED.add(test_name) actual_path = fixture_test_path / "actual" @@ -79,6 +82,7 @@ def _process_tested(fixture_test_path, test_name): pytest.fail(f"Hash of {test_name} not found in fixtures.json") if actual_hash != expected_hash: + assert expected_hash is not None file_path = testreport.failed( fixture_test_path, test_name, actual_hash, expected_hash ) @@ -94,7 +98,9 @@ def _process_tested(fixture_test_path, test_name): @contextmanager -def screen_recording(client, request): +def screen_recording( + client: Client, request: pytest.FixtureRequest +) -> Generator[None, None, None]: test_ui = request.config.getoption("ui") test_name = get_test_name(request.node.nodeid) screens_test_path = SCREENS_DIR / test_name @@ -126,26 +132,26 @@ def screen_recording(client, request): _process_tested(screens_test_path, test_name) -def list_missing(): +def list_missing() -> Set[str]: return set(FILE_HASHES.keys()) - PROCESSED -def read_fixtures(): +def read_fixtures() -> None: if not HASH_FILE.exists(): raise ValueError("File fixtures.json not found.") global FILE_HASHES FILE_HASHES = json.loads(HASH_FILE.read_text()) -def write_fixtures(remove_missing: bool): +def write_fixtures(remove_missing: bool) -> None: HASH_FILE.write_text(_get_fixtures_content(FILE_HASHES, remove_missing)) -def write_fixtures_suggestion(remove_missing: bool): +def write_fixtures_suggestion(remove_missing: bool) -> None: SUGGESTION_FILE.write_text(_get_fixtures_content(ACTUAL_HASHES, remove_missing)) -def _get_fixtures_content(fixtures: dict, remove_missing: bool): +def _get_fixtures_content(fixtures: Dict[str, str], remove_missing: bool) -> str: if remove_missing: fixtures = {i: fixtures[i] for i in PROCESSED} else: @@ -154,7 +160,7 @@ def _get_fixtures_content(fixtures: dict, remove_missing: bool): return json.dumps(fixtures, indent="", sort_keys=True) + "\n" -def main(): +def main() -> None: read_fixtures() for record in SCREENS_DIR.iterdir(): if not (record / "actual").exists(): diff --git a/tests/ui_tests/reporting/download.py b/tests/ui_tests/reporting/download.py index e432d6140..3ffbb4817 100644 --- a/tests/ui_tests/reporting/download.py +++ b/tests/ui_tests/reporting/download.py @@ -14,7 +14,7 @@ FIXTURES_CURRENT = Path(__file__).resolve().parent.parent / "fixtures.json" _dns_failed = False -def fetch_recorded(hash, path): +def fetch_recorded(hash: str, path: Path) -> None: global _dns_failed if _dns_failed: diff --git a/tests/ui_tests/reporting/html.py b/tests/ui_tests/reporting/html.py index c4bbeef6c..611f362dd 100644 --- a/tests/ui_tests/reporting/html.py +++ b/tests/ui_tests/reporting/html.py @@ -1,11 +1,15 @@ import base64 import filecmp from itertools import zip_longest +from pathlib import Path +from typing import Dict, List from dominate.tags import a, i, img, table, td, th, tr -def report_links(tests, reports_path, actual_hashes=None): +def report_links( + tests: List[Path], reports_path: Path, actual_hashes: Dict[str, str] = None +) -> None: if actual_hashes is None: actual_hashes = {} @@ -21,12 +25,12 @@ def report_links(tests, reports_path, actual_hashes=None): td(a(test.name, href=path)) -def write(fixture_test_path, doc, filename): +def write(fixture_test_path: Path, doc, filename: str) -> Path: (fixture_test_path / filename).write_text(doc.render()) return fixture_test_path / filename -def image(src): +def image(src: Path) -> None: with td(): if src: # open image file @@ -41,7 +45,7 @@ def image(src): i("missing") -def diff_table(left_screens, right_screens): +def diff_table(left_screens: List[Path], right_screens: List[Path]) -> None: for left, right in zip_longest(left_screens, right_screens): if left and right and filecmp.cmp(right, left): background = "white" diff --git a/tests/ui_tests/reporting/report_master_diff.py b/tests/ui_tests/reporting/report_master_diff.py index 4f3ec9c01..ca294a1a7 100644 --- a/tests/ui_tests/reporting/report_master_diff.py +++ b/tests/ui_tests/reporting/report_master_diff.py @@ -2,6 +2,7 @@ import shutil import tempfile from contextlib import contextmanager from pathlib import Path +from typing import Dict, Tuple import dominate from dominate.tags import br, h1, h2, hr, i, p, table, td, th, tr @@ -14,7 +15,7 @@ REPORTS_PATH = Path(__file__).resolve().parent / "reports" / "master_diff" RECORDED_SCREENS_PATH = Path(__file__).resolve().parent.parent / "screens" -def get_diff(): +def get_diff() -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]: master = download.fetch_fixtures_master() current = download.fetch_fixtures_current() @@ -34,7 +35,7 @@ def get_diff(): return removed, added, diff -def removed(screens_path, test_name): +def removed(screens_path: Path, test_name: str) -> Path: doc = dominate.document(title=test_name) screens = sorted(screens_path.iterdir()) @@ -57,7 +58,7 @@ def removed(screens_path, test_name): return html.write(REPORTS_PATH / "removed", doc, test_name + ".html") -def added(screens_path, test_name): +def added(screens_path: Path, test_name: str) -> Path: doc = dominate.document(title=test_name) screens = sorted(screens_path.iterdir()) @@ -81,8 +82,12 @@ def added(screens_path, test_name): def diff( - master_screens_path, current_screens_path, test_name, master_hash, current_hash -): + master_screens_path: Path, + current_screens_path: Path, + test_name: str, + master_hash: str, + current_hash: str, +) -> Path: doc = dominate.document(title=test_name) master_screens = sorted(master_screens_path.iterdir()) current_screens = sorted(current_screens_path.iterdir()) @@ -109,7 +114,7 @@ def diff( return html.write(REPORTS_PATH / "diff", doc, test_name + ".html") -def index(): +def index() -> Path: removed = list((REPORTS_PATH / "removed").iterdir()) added = list((REPORTS_PATH / "added").iterdir()) diff = list((REPORTS_PATH / "diff").iterdir()) @@ -140,7 +145,7 @@ def index(): return html.write(REPORTS_PATH, doc, "index.html") -def create_dirs(): +def create_dirs() -> None: # delete the reports dir to clear previous entries and create folders shutil.rmtree(REPORTS_PATH, ignore_errors=True) REPORTS_PATH.mkdir() @@ -149,7 +154,7 @@ def create_dirs(): (REPORTS_PATH / "diff").mkdir() -def create_reports(): +def create_reports() -> None: removed_tests, added_tests, diff_tests = get_diff() @contextmanager diff --git a/tests/ui_tests/reporting/testreport.py b/tests/ui_tests/reporting/testreport.py index 36f912557..03e054a14 100644 --- a/tests/ui_tests/reporting/testreport.py +++ b/tests/ui_tests/reporting/testreport.py @@ -2,6 +2,7 @@ import shutil from datetime import datetime from distutils.dir_util import copy_tree from pathlib import Path +from typing import Dict import dominate import dominate.tags as t @@ -16,10 +17,12 @@ REPORTS_PATH = HERE / "reports" / "test" STYLE = (HERE / "testreport.css").read_text() SCRIPT = (HERE / "testreport.js").read_text() -ACTUAL_HASHES = {} +ACTUAL_HASHES: Dict[str, str] = {} -def document(title, actual_hash=None, index=False): +def document( + title: str, actual_hash: str = None, index: bool = False +) -> dominate.document: doc = dominate.document(title=title) style = t.style() style.add_raw_string(STYLE) @@ -36,7 +39,7 @@ def document(title, actual_hash=None, index=False): return doc -def _header(test_name, expected_hash, actual_hash): +def _header(test_name: str, expected_hash: str, actual_hash: str) -> None: h1(test_name) with div(): if actual_hash == expected_hash: @@ -54,7 +57,7 @@ def _header(test_name, expected_hash, actual_hash): hr() -def clear_dir(): +def clear_dir() -> None: # delete and create the reports dir to clear previous entries shutil.rmtree(REPORTS_PATH, ignore_errors=True) REPORTS_PATH.mkdir() @@ -62,7 +65,7 @@ def clear_dir(): (REPORTS_PATH / "passed").mkdir() -def index(): +def index() -> Path: passed_tests = list((REPORTS_PATH / "passed").iterdir()) failed_tests = list((REPORTS_PATH / "failed").iterdir()) @@ -104,7 +107,9 @@ def index(): return html.write(REPORTS_PATH, doc, "index.html") -def failed(fixture_test_path, test_name, actual_hash, expected_hash): +def failed( + fixture_test_path: Path, test_name: str, actual_hash: str, expected_hash: str +) -> Path: ACTUAL_HASHES[test_name] = actual_hash doc = document(title=test_name, actual_hash=actual_hash) @@ -147,7 +152,7 @@ def failed(fixture_test_path, test_name, actual_hash, expected_hash): return html.write(REPORTS_PATH / "failed", doc, test_name + ".html") -def passed(fixture_test_path, test_name, actual_hash): +def passed(fixture_test_path: Path, test_name: str, actual_hash: str) -> Path: copy_tree(str(fixture_test_path / "actual"), str(fixture_test_path / "recorded")) doc = document(title=test_name) diff --git a/tests/upgrade_tests/__init__.py b/tests/upgrade_tests/__init__.py index d6b9d6685..5b922a142 100644 --- a/tests/upgrade_tests/__init__.py +++ b/tests/upgrade_tests/__init__.py @@ -15,8 +15,10 @@ # If not, see . import os +from typing import List, Tuple import pytest +from _pytest.mark.structures import MarkDecorator from ..emulators import ALL_TAGS, LOCAL_BUILD_PATHS @@ -44,7 +46,11 @@ core_only = pytest.mark.skipif( ) -def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0, 0)): +def for_all( + *args, + legacy_minimum_version: Tuple[int, int, int] = (1, 0, 0), + core_minimum_version: Tuple[int, int, int] = (2, 0, 0) +) -> "MarkDecorator": """Parametrizing decorator for test cases. Usage example: @@ -98,7 +104,7 @@ def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0, return pytest.mark.parametrize("gen, tag", all_params) -def for_tags(*args): +def for_tags(*args: Tuple[str, List[str]]) -> "MarkDecorator": enabled_gens = SELECTED_GENS or ("core", "legacy") return pytest.mark.parametrize( "gen, tags", [(gen, tags) for gen, tags in args if gen in enabled_gens]