chore(tests): add type hints to helper test functions

pull/2107/head
grdddj 2 years ago committed by matejcik
parent 5d76144ef5
commit c77e18d77c

@ -18,6 +18,7 @@ import hashlib
import hmac
import struct
from copy import copy
from typing import Any, List, Tuple
import ecdsa
from ecdsa.curves import SECP256k1
@ -27,7 +28,7 @@ from ecdsa.util import number_to_string, string_to_number
from trezorlib import messages, tools
def point_to_pubkey(point):
def point_to_pubkey(point: Point) -> bytes:
order = SECP256k1.order
x_str = number_to_string(point.x(), order)
y_str = number_to_string(point.y(), order)
@ -35,14 +36,14 @@ def point_to_pubkey(point):
return struct.pack("B", (vk[63] & 1) + 2) + vk[0:32] # To compressed key
def sec_to_public_pair(pubkey):
def sec_to_public_pair(pubkey: bytes) -> Tuple[int, Any]:
"""Convert a public key in sec binary format to a public pair."""
x = string_to_number(pubkey[1:33])
sec0 = pubkey[:1]
if sec0 not in (b"\2", b"\3"):
raise ValueError("Compressed pubkey expected")
def public_pair_for_x(generator, x, is_even):
def public_pair_for_x(generator, x: int, is_even: bool) -> Tuple[int, Any]:
curve = generator.curve()
p = curve.p()
alpha = (pow(x, 3, p) + curve.a() * x + curve.b()) % p
@ -56,15 +57,15 @@ def sec_to_public_pair(pubkey):
)
def fingerprint(pubkey):
def fingerprint(pubkey: bytes) -> int:
return string_to_number(tools.hash_160(pubkey)[:4])
def get_address(public_node, address_type):
def get_address(public_node: messages.HDNodeType, address_type: int) -> str:
return tools.public_key_to_bc_address(public_node.public_key, address_type)
def public_ckd(public_node, n):
def public_ckd(public_node: messages.HDNodeType, n: List[int]):
if not isinstance(n, list):
raise ValueError("Parameter must be a list")
@ -76,7 +77,7 @@ def public_ckd(public_node, n):
return node
def get_subnode(node, i):
def get_subnode(node: messages.HDNodeType, i: int) -> messages.HDNodeType:
# Public Child key derivation (CKD) algorithm of BIP32
i_as_bytes = struct.pack(">L", i)
@ -108,7 +109,7 @@ def get_subnode(node, i):
)
def serialize(node, version=0x0488B21E):
def serialize(node: messages.HDNodeType, version: int = 0x0488B21E) -> str:
s = b""
s += struct.pack(">I", version)
s += struct.pack(">B", node.depth)
@ -123,7 +124,7 @@ def serialize(node, version=0x0488B21E):
return tools.b58encode(s)
def deserialize(xpub):
def deserialize(xpub: str) -> messages.HDNodeType:
data = tools.b58decode(xpub, None)
if tools.btc_hash(data[:-4])[:4] != data[-4:]:

@ -1,8 +1,10 @@
from typing import Iterator, Tuple
DISPLAY_WIDTH = 240
DISPLAY_HEIGHT = 240
def grid(dim, grid_cells, cell):
def grid(dim: int, grid_cells: int, cell: int) -> int:
step = dim // grid_cells
ofs = step // 2
return cell * step + ofs
@ -34,15 +36,15 @@ RESET_WORD_CHECK = [
BUTTON_LETTERS = ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz")
def grid35(x, y):
def grid35(x: int, y: int) -> Tuple[int, int]:
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 5, y)
def grid34(x, y):
def grid34(x: int, y: int) -> Tuple[int, int]:
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 4, y)
def type_word(word):
def type_word(word: str) -> Iterator[Tuple[int, int]]:
for l in word:
idx = next(i for i, letters in enumerate(BUTTON_LETTERS) if l in letters)
grid_x = idx % 3

@ -18,12 +18,18 @@ import json
import os
from decimal import Decimal
from pathlib import Path
from typing import TYPE_CHECKING, Generator, List, Optional
import pytest
import requests
from trezorlib import btc, tools
from trezorlib.messages import ButtonRequestType as B
from trezorlib.messages import ButtonRequestType
if TYPE_CHECKING:
from trezorlib.debuglink import DebugLink, TrezorClientDebugLink as Client
from trezorlib.messages import ButtonRequest
from _pytest.mark.structures import MarkDecorator
# fmt: off
# 1 2 3 4 5 6 7 8 9 10 11 12
@ -56,7 +62,7 @@ COMMON_FIXTURES_DIR = (
)
def parametrize_using_common_fixtures(*paths):
def parametrize_using_common_fixtures(*paths: str) -> "MarkDecorator":
fixtures = []
for path in paths:
fixtures.append(json.loads((COMMON_FIXTURES_DIR / path).read_text()))
@ -85,7 +91,9 @@ def parametrize_using_common_fixtures(*paths):
return pytest.mark.parametrize("parameters, result", tests)
def generate_entropy(strength, internal_entropy, external_entropy):
def generate_entropy(
strength: int, internal_entropy: bytes, external_entropy: bytes
) -> bytes:
"""
strength - length of produced seed. One of 128, 192, 256
random - binary stream of random data from external HRNG
@ -116,7 +124,12 @@ def generate_entropy(strength, internal_entropy, external_entropy):
return entropy_stripped
def recovery_enter_shares(debug, shares, groups=False, click_info=False):
def recovery_enter_shares(
debug: "DebugLink",
shares: List[str],
groups: bool = False,
click_info: bool = False,
) -> Generator[None, "ButtonRequest", None]:
"""Perform the recovery flow for a set of Shamir shares.
For use in an input flow function.
@ -134,7 +147,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
debug.press_yes()
# Input word number
br = yield
assert br.code == B.MnemonicWordCount
assert br.code == ButtonRequestType.MnemonicWordCount
debug.input(str(word_count))
# Homescreen - proceed to share entry
yield
@ -142,7 +155,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
# Enter shares
for share in shares:
br = yield
assert br.code == B.MnemonicInput
assert br.code == ButtonRequestType.MnemonicInput
# Enter mnemonic words
for word in share.split(" "):
debug.input(word)
@ -167,7 +180,9 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
debug.press_yes()
def click_through(debug, screens, code=None):
def click_through(
debug: "DebugLink", screens: int, code: ButtonRequestType = None
) -> Generator[None, "ButtonRequest", None]:
"""Click through N dialog screens.
For use in an input flow function.
@ -178,7 +193,7 @@ def click_through(debug, screens, code=None):
# 2. Backup your seed
# 3. Confirm warning
# 4. Shares info
yield from click_through(client.debug, screens=4, code=B.ResetDevice)
yield from click_through(client.debug, screens=4, code=ButtonRequestType.ResetDevice)
"""
for _ in range(screens):
received = yield
@ -187,7 +202,9 @@ def click_through(debug, screens, code=None):
debug.press_yes()
def read_and_confirm_mnemonic(debug, choose_wrong=False):
def read_and_confirm_mnemonic(
debug: "DebugLink", choose_wrong: bool = False
) -> Generator[None, "ButtonRequest", Optional[str]]:
"""Read a given number of mnemonic words from Trezor T screen and correctly
answer confirmation questions. Return the full mnemonic.
@ -201,6 +218,7 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False):
"""
mnemonic = []
br = yield
assert br.pages is not None
for _ in range(br.pages - 1):
mnemonic.extend(debug.read_reset_word().split())
debug.swipe_up(wait=True)
@ -221,13 +239,15 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False):
return " ".join(mnemonic)
def get_test_address(client):
def get_test_address(client: "Client") -> str:
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
protected call, or to identify the root secret (seed+passphrase)"""
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
def assert_tx_matches(serialized_tx: bytes, hash_link: str, tx_hex: str = None) -> None:
def assert_tx_matches(
serialized_tx: bytes, hash_link: str, tx_hex: str = None
) -> None:
"""Verifies if a transaction is correctly formed."""
hash_str = hash_link.split("/")[-1]
assert tools.tx_hash(serialized_tx).hex() == hash_str

@ -15,6 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os
from typing import TYPE_CHECKING
import pytest
@ -27,12 +28,17 @@ from . import ui_tests
from .device_handler import BackgroundDeviceHandler
from .ui_tests.reporting import testreport
if TYPE_CHECKING:
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.terminal import TerminalReporter
# So that we see details of failed asserts from this module
pytest.register_assert_rewrite("tests.common")
@pytest.fixture(scope="session")
def _raw_client(request):
def _raw_client(request: pytest.FixtureRequest) -> TrezorClientDebugLink:
path = os.environ.get("TREZOR_PATH")
interact = int(os.environ.get("INTERACT", 0))
if path:
@ -56,7 +62,9 @@ def _raw_client(request):
@pytest.fixture(scope="function")
def client(request, _raw_client):
def client(
request: pytest.FixtureRequest, _raw_client: TrezorClientDebugLink
) -> TrezorClientDebugLink:
"""Client fixture.
Every test function that requires a client instance will get it from here.
@ -156,20 +164,20 @@ def client(request, _raw_client):
_raw_client.close()
def pytest_sessionstart(session):
def pytest_sessionstart(session: pytest.Session) -> None:
ui_tests.read_fixtures()
if session.config.getoption("ui") == "test":
testreport.clear_dir()
def _should_write_ui_report(exitstatus):
def _should_write_ui_report(exitstatus: pytest.ExitCode) -> bool:
# generate UI report and check missing only if pytest is exitting cleanly
# I.e., the test suite passed or failed (as opposed to ctrl+c break, internal error,
# etc.)
return exitstatus in (pytest.ExitCode.OK, pytest.ExitCode.TESTS_FAILED)
def pytest_sessionfinish(session, exitstatus):
def pytest_sessionfinish(session: pytest.Session, exitstatus: pytest.ExitCode) -> None:
if not _should_write_ui_report(exitstatus):
return
@ -183,7 +191,9 @@ def pytest_sessionfinish(session, exitstatus):
ui_tests.write_fixtures(missing)
def pytest_terminal_summary(terminalreporter, exitstatus, config):
def pytest_terminal_summary(
terminalreporter: "TerminalReporter", exitstatus: pytest.ExitCode, config: "Config"
) -> None:
println = terminalreporter.write_line
println("")
@ -213,7 +223,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
println("")
def pytest_addoption(parser):
def pytest_addoption(parser: "Parser") -> None:
parser.addoption(
"--ui",
action="store",
@ -229,7 +239,7 @@ def pytest_addoption(parser):
)
def pytest_configure(config):
def pytest_configure(config: "Config") -> None:
"""Called at testsuite setup time.
Registers known markers, enables verbose output if requested.
@ -253,7 +263,7 @@ def pytest_configure(config):
log.enable_debug_output()
def pytest_runtest_setup(item):
def pytest_runtest_setup(item: pytest.Item) -> None:
"""Called for each test item (class, individual tests).
Ensures that altcoin tests are skipped, and that no test is skipped on
@ -267,7 +277,7 @@ def pytest_runtest_setup(item):
pytest.skip("Skipping altcoin test")
def pytest_runtest_teardown(item):
def pytest_runtest_teardown(item: pytest.Item) -> None:
"""Called after a test item finishes.
Dumps the current UI test report HTML.
@ -277,7 +287,7 @@ def pytest_runtest_teardown(item):
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item, call):
def pytest_runtest_makereport(item: pytest.Item, call) -> None:
# Make test results available in fixtures.
# See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
# The device_handler fixture uses this as 'request.node.rep_call.passed' attribute,
@ -288,7 +298,9 @@ def pytest_runtest_makereport(item, call):
@pytest.fixture
def device_handler(client, request):
def device_handler(
client: TrezorClientDebugLink, request: pytest.FixtureRequest
) -> None:
device_handler = BackgroundDeviceHandler(client)
yield device_handler
@ -299,5 +311,5 @@ def device_handler(client, request):
# if test finished, make sure all background tasks are done
finalized_ok = device_handler.check_finalize()
if request.node.rep_call.passed and not finalized_ok:
if request.node.rep_call.passed and not finalized_ok: # type: ignore [rep_call must exist]
raise RuntimeError("Test did not check result of background task")

@ -1,8 +1,15 @@
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING
from trezorlib.client import PASSPHRASE_ON_DEVICE
from trezorlib.transport import udp
if TYPE_CHECKING:
from trezorlib.messages import Features
from trezorlib.debuglink import TrezorClientDebugLink, DebugLink
from trezorlib._internal.emulator import Emulator
udp.SOCKET_TIMEOUT = 0.1
@ -26,13 +33,13 @@ class NullUI:
class BackgroundDeviceHandler:
_pool = ThreadPoolExecutor()
def __init__(self, client):
def __init__(self, client: "TrezorClientDebugLink") -> None:
self._configure_client(client)
self.task = None
def _configure_client(self, client):
def _configure_client(self, client: "TrezorClientDebugLink") -> None:
self.client = client
self.client.ui = NullUI
self.client.ui = NullUI # type: ignore [NullUI is OK UI]
self.client.watch_layout(True)
def run(self, function, *args, **kwargs):
@ -40,7 +47,7 @@ class BackgroundDeviceHandler:
raise RuntimeError("Wait for previous task first")
self.task = self._pool.submit(function, self.client, *args, **kwargs)
def kill_task(self):
def kill_task(self) -> None:
if self.task is not None:
# Force close the client, which should raise an exception in a client
# waiting on IO. Does not work over Bridge, because bridge doesn't have
@ -53,11 +60,11 @@ class BackgroundDeviceHandler:
pass
self.task = None
def restart(self, emulator):
def restart(self, emulator: "Emulator"):
# TODO handle actual restart as well
self.kill_task()
emulator.restart()
self._configure_client(emulator.client)
self._configure_client(emulator.client) # type: ignore [client cannot be None]
def result(self):
if self.task is None:
@ -67,25 +74,25 @@ class BackgroundDeviceHandler:
finally:
self.task = None
def features(self):
def features(self) -> "Features":
if self.task is not None:
raise RuntimeError("Cannot query features while task is running")
self.client.init_device()
return self.client.features
def debuglink(self):
def debuglink(self) -> "DebugLink":
return self.client.debug
def check_finalize(self):
def check_finalize(self) -> bool:
if self.task is not None:
self.kill_task()
return False
return True
def __enter__(self):
def __enter__(self) -> "BackgroundDeviceHandler":
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback) -> None:
finalized_ok = self.check_finalize()
if exc_type is None and not finalized_ok:
raise RuntimeError("Exit while task is unfinished")

@ -17,8 +17,9 @@
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from trezorlib._internal.emulator import CoreEmulator, LegacyEmulator
from trezorlib._internal.emulator import CoreEmulator, Emulator, LegacyEmulator
ROOT = Path(__file__).resolve().parent.parent
BINDIR = ROOT / "tests" / "emulators"
@ -33,18 +34,18 @@ CORE_SRC_DIR = ROOT / "core" / "src"
ENV = {"SDL_VIDEODRIVER": "dummy"}
def check_version(tag, version_tuple):
def check_version(tag: str, version_tuple: Tuple[int, int, int]) -> None:
if tag is not None and tag.startswith("v") and len(tag.split(".")) == 3:
version = ".".join(str(i) for i in version_tuple)
if tag[1:] != version:
raise RuntimeError(f"Version mismatch: tag {tag} reports version {version}")
def filename_from_tag(gen, tag):
def filename_from_tag(gen: str, tag: str) -> Path:
return BINDIR / f"trezor-emu-{gen}-{tag}"
def get_tags():
def get_tags() -> Dict[str, List[str]]:
files = list(BINDIR.iterdir())
if not files:
raise ValueError(
@ -66,7 +67,7 @@ ALL_TAGS = get_tags()
class EmulatorWrapper:
def __init__(self, gen, tag=None, storage=None):
def __init__(self, gen: str, tag: str = None, storage: bytes = None) -> None:
if tag is not None:
executable = filename_from_tag(gen, tag)
else:
@ -96,11 +97,15 @@ class EmulatorWrapper:
workdir=workdir,
headless=True,
)
else:
raise ValueError(
f"Unrecognized gen - {gen} - only 'core' and 'legacy' supported"
)
def __enter__(self):
def __enter__(self) -> Emulator:
self.emulator.start()
return self.emulator
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback) -> None:
self.emulator.stop()
self.profile_dir.cleanup()

@ -5,9 +5,9 @@ import multiprocessing
import os
import posixpath
import time
import urllib
import webbrowser
from pathlib import Path
from urllib.parse import unquote
import click
@ -16,16 +16,16 @@ TEST_RESULT_PATH = ROOT / "tests" / "ui_tests" / "reporting" / "reports" / "test
class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
def end_headers(self):
def end_headers(self) -> None:
self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
self.send_header("Pragma", "no-cache")
self.send_header("Expires", "0")
return super().end_headers()
def log_message(self, format, *args):
def log_message(self, format, *args) -> None:
pass
def translate_path(self, path):
def translate_path(self, path: str) -> str:
# XXX
# Copy-pasted from Python 3.8 BaseHTTPRequestHandler so that we can inject
# the `directory` parameter.
@ -36,9 +36,9 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
# Don't forget explicit trailing slash when normalizing. Issue17324
trailing_slash = path.rstrip().endswith("/")
try:
path = urllib.parse.unquote(path, errors="surrogatepass")
path = unquote(path, errors="surrogatepass")
except UnicodeDecodeError:
path = urllib.parse.unquote(path)
path = unquote(path)
path = posixpath.normpath(path)
words = path.split("/")
words = filter(None, words)
@ -53,13 +53,13 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
return path
def launch_http_server(port):
http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port)
def launch_http_server(port: int) -> None:
http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port) # type: ignore [test is defined]
@click.command()
@click.option("-p", "--port", type=int, default=8000)
def main(port):
def main(port: int):
httpd = multiprocessing.Process(target=launch_http_server, args=(port,))
httpd.start()
time.sleep(0.5)

@ -122,7 +122,7 @@ def cli(tx, coin_name):
tx_dict = protobuf.to_dict(tx_proto)
tx_json = json.dumps(tx_dict, sort_keys=True, indent=2) + "\n"
except Exception as e:
raise click.ClickException(e) from e
raise click.ClickException(str(e)) from e
cache_dir = CACHE_PATH / coin_name
if not cache_dir.exists():

@ -4,23 +4,26 @@ import re
import shutil
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Generator, Set
import pytest
from _pytest.outcomes import Failed
from PIL import Image
from trezorlib.debuglink import TrezorClientDebugLink as Client
from .reporting import testreport
UI_TESTS_DIR = Path(__file__).resolve().parent
SCREENS_DIR = UI_TESTS_DIR / "screens"
HASH_FILE = UI_TESTS_DIR / "fixtures.json"
SUGGESTION_FILE = UI_TESTS_DIR / "fixtures.suggestion.json"
FILE_HASHES = {}
ACTUAL_HASHES = {}
PROCESSED = set()
FILE_HASHES: Dict[str, str] = {}
ACTUAL_HASHES: Dict[str, str] = {}
PROCESSED: Set[str] = set()
def get_test_name(node_id):
def get_test_name(node_id: str) -> str:
# Test item name is usually function name, but when parametrization is used,
# parameters are also part of the name. Some functions have very long parameter
# names (tx hashes etc) that run out of maximum allowable filename length, so
@ -34,14 +37,14 @@ def get_test_name(node_id):
return new_name[:91] + "-" + hashlib.sha256(new_name.encode()).hexdigest()[:8]
def _process_recorded(screen_path, test_name):
def _process_recorded(screen_path: Path, test_name: str) -> None:
# calculate hash
FILE_HASHES[test_name] = _hash_files(screen_path)
_rename_records(screen_path)
PROCESSED.add(test_name)
def _rename_records(screen_path):
def _rename_records(screen_path: Path) -> None:
# rename screenshots
for index, record in enumerate(sorted(screen_path.iterdir())):
record.replace(screen_path / f"{index:08}.png")
@ -65,7 +68,7 @@ def _get_bytes_from_png(png_file: str) -> bytes:
return Image.open(png_file).tobytes()
def _process_tested(fixture_test_path, test_name):
def _process_tested(fixture_test_path: Path, test_name: str) -> None:
PROCESSED.add(test_name)
actual_path = fixture_test_path / "actual"
@ -79,6 +82,7 @@ def _process_tested(fixture_test_path, test_name):
pytest.fail(f"Hash of {test_name} not found in fixtures.json")
if actual_hash != expected_hash:
assert expected_hash is not None
file_path = testreport.failed(
fixture_test_path, test_name, actual_hash, expected_hash
)
@ -94,7 +98,9 @@ def _process_tested(fixture_test_path, test_name):
@contextmanager
def screen_recording(client, request):
def screen_recording(
client: Client, request: pytest.FixtureRequest
) -> Generator[None, None, None]:
test_ui = request.config.getoption("ui")
test_name = get_test_name(request.node.nodeid)
screens_test_path = SCREENS_DIR / test_name
@ -126,26 +132,26 @@ def screen_recording(client, request):
_process_tested(screens_test_path, test_name)
def list_missing():
def list_missing() -> Set[str]:
return set(FILE_HASHES.keys()) - PROCESSED
def read_fixtures():
def read_fixtures() -> None:
if not HASH_FILE.exists():
raise ValueError("File fixtures.json not found.")
global FILE_HASHES
FILE_HASHES = json.loads(HASH_FILE.read_text())
def write_fixtures(remove_missing: bool):
def write_fixtures(remove_missing: bool) -> None:
HASH_FILE.write_text(_get_fixtures_content(FILE_HASHES, remove_missing))
def write_fixtures_suggestion(remove_missing: bool):
def write_fixtures_suggestion(remove_missing: bool) -> None:
SUGGESTION_FILE.write_text(_get_fixtures_content(ACTUAL_HASHES, remove_missing))
def _get_fixtures_content(fixtures: dict, remove_missing: bool):
def _get_fixtures_content(fixtures: Dict[str, str], remove_missing: bool) -> str:
if remove_missing:
fixtures = {i: fixtures[i] for i in PROCESSED}
else:
@ -154,7 +160,7 @@ def _get_fixtures_content(fixtures: dict, remove_missing: bool):
return json.dumps(fixtures, indent="", sort_keys=True) + "\n"
def main():
def main() -> None:
read_fixtures()
for record in SCREENS_DIR.iterdir():
if not (record / "actual").exists():

@ -14,7 +14,7 @@ FIXTURES_CURRENT = Path(__file__).resolve().parent.parent / "fixtures.json"
_dns_failed = False
def fetch_recorded(hash, path):
def fetch_recorded(hash: str, path: Path) -> None:
global _dns_failed
if _dns_failed:

@ -1,11 +1,15 @@
import base64
import filecmp
from itertools import zip_longest
from pathlib import Path
from typing import Dict, List
from dominate.tags import a, i, img, table, td, th, tr
def report_links(tests, reports_path, actual_hashes=None):
def report_links(
tests: List[Path], reports_path: Path, actual_hashes: Dict[str, str] = None
) -> None:
if actual_hashes is None:
actual_hashes = {}
@ -21,12 +25,12 @@ def report_links(tests, reports_path, actual_hashes=None):
td(a(test.name, href=path))
def write(fixture_test_path, doc, filename):
def write(fixture_test_path: Path, doc, filename: str) -> Path:
(fixture_test_path / filename).write_text(doc.render())
return fixture_test_path / filename
def image(src):
def image(src: Path) -> None:
with td():
if src:
# open image file
@ -41,7 +45,7 @@ def image(src):
i("missing")
def diff_table(left_screens, right_screens):
def diff_table(left_screens: List[Path], right_screens: List[Path]) -> None:
for left, right in zip_longest(left_screens, right_screens):
if left and right and filecmp.cmp(right, left):
background = "white"

@ -2,6 +2,7 @@ import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Tuple
import dominate
from dominate.tags import br, h1, h2, hr, i, p, table, td, th, tr
@ -14,7 +15,7 @@ REPORTS_PATH = Path(__file__).resolve().parent / "reports" / "master_diff"
RECORDED_SCREENS_PATH = Path(__file__).resolve().parent.parent / "screens"
def get_diff():
def get_diff() -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
master = download.fetch_fixtures_master()
current = download.fetch_fixtures_current()
@ -34,7 +35,7 @@ def get_diff():
return removed, added, diff
def removed(screens_path, test_name):
def removed(screens_path: Path, test_name: str) -> Path:
doc = dominate.document(title=test_name)
screens = sorted(screens_path.iterdir())
@ -57,7 +58,7 @@ def removed(screens_path, test_name):
return html.write(REPORTS_PATH / "removed", doc, test_name + ".html")
def added(screens_path, test_name):
def added(screens_path: Path, test_name: str) -> Path:
doc = dominate.document(title=test_name)
screens = sorted(screens_path.iterdir())
@ -81,8 +82,12 @@ def added(screens_path, test_name):
def diff(
master_screens_path, current_screens_path, test_name, master_hash, current_hash
):
master_screens_path: Path,
current_screens_path: Path,
test_name: str,
master_hash: str,
current_hash: str,
) -> Path:
doc = dominate.document(title=test_name)
master_screens = sorted(master_screens_path.iterdir())
current_screens = sorted(current_screens_path.iterdir())
@ -109,7 +114,7 @@ def diff(
return html.write(REPORTS_PATH / "diff", doc, test_name + ".html")
def index():
def index() -> Path:
removed = list((REPORTS_PATH / "removed").iterdir())
added = list((REPORTS_PATH / "added").iterdir())
diff = list((REPORTS_PATH / "diff").iterdir())
@ -140,7 +145,7 @@ def index():
return html.write(REPORTS_PATH, doc, "index.html")
def create_dirs():
def create_dirs() -> None:
# delete the reports dir to clear previous entries and create folders
shutil.rmtree(REPORTS_PATH, ignore_errors=True)
REPORTS_PATH.mkdir()
@ -149,7 +154,7 @@ def create_dirs():
(REPORTS_PATH / "diff").mkdir()
def create_reports():
def create_reports() -> None:
removed_tests, added_tests, diff_tests = get_diff()
@contextmanager

@ -2,6 +2,7 @@ import shutil
from datetime import datetime
from distutils.dir_util import copy_tree
from pathlib import Path
from typing import Dict
import dominate
import dominate.tags as t
@ -16,10 +17,12 @@ REPORTS_PATH = HERE / "reports" / "test"
STYLE = (HERE / "testreport.css").read_text()
SCRIPT = (HERE / "testreport.js").read_text()
ACTUAL_HASHES = {}
ACTUAL_HASHES: Dict[str, str] = {}
def document(title, actual_hash=None, index=False):
def document(
title: str, actual_hash: str = None, index: bool = False
) -> dominate.document:
doc = dominate.document(title=title)
style = t.style()
style.add_raw_string(STYLE)
@ -36,7 +39,7 @@ def document(title, actual_hash=None, index=False):
return doc
def _header(test_name, expected_hash, actual_hash):
def _header(test_name: str, expected_hash: str, actual_hash: str) -> None:
h1(test_name)
with div():
if actual_hash == expected_hash:
@ -54,7 +57,7 @@ def _header(test_name, expected_hash, actual_hash):
hr()
def clear_dir():
def clear_dir() -> None:
# delete and create the reports dir to clear previous entries
shutil.rmtree(REPORTS_PATH, ignore_errors=True)
REPORTS_PATH.mkdir()
@ -62,7 +65,7 @@ def clear_dir():
(REPORTS_PATH / "passed").mkdir()
def index():
def index() -> Path:
passed_tests = list((REPORTS_PATH / "passed").iterdir())
failed_tests = list((REPORTS_PATH / "failed").iterdir())
@ -104,7 +107,9 @@ def index():
return html.write(REPORTS_PATH, doc, "index.html")
def failed(fixture_test_path, test_name, actual_hash, expected_hash):
def failed(
fixture_test_path: Path, test_name: str, actual_hash: str, expected_hash: str
) -> Path:
ACTUAL_HASHES[test_name] = actual_hash
doc = document(title=test_name, actual_hash=actual_hash)
@ -147,7 +152,7 @@ def failed(fixture_test_path, test_name, actual_hash, expected_hash):
return html.write(REPORTS_PATH / "failed", doc, test_name + ".html")
def passed(fixture_test_path, test_name, actual_hash):
def passed(fixture_test_path: Path, test_name: str, actual_hash: str) -> Path:
copy_tree(str(fixture_test_path / "actual"), str(fixture_test_path / "recorded"))
doc = document(title=test_name)

@ -15,8 +15,10 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import os
from typing import List, Tuple
import pytest
from _pytest.mark.structures import MarkDecorator
from ..emulators import ALL_TAGS, LOCAL_BUILD_PATHS
@ -44,7 +46,11 @@ core_only = pytest.mark.skipif(
)
def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0, 0)):
def for_all(
*args,
legacy_minimum_version: Tuple[int, int, int] = (1, 0, 0),
core_minimum_version: Tuple[int, int, int] = (2, 0, 0)
) -> "MarkDecorator":
"""Parametrizing decorator for test cases.
Usage example:
@ -98,7 +104,7 @@ def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0,
return pytest.mark.parametrize("gen, tag", all_params)
def for_tags(*args):
def for_tags(*args: Tuple[str, List[str]]) -> "MarkDecorator":
enabled_gens = SELECTED_GENS or ("core", "legacy")
return pytest.mark.parametrize(
"gen, tags", [(gen, tags) for gen, tags in args if gen in enabled_gens]

Loading…
Cancel
Save