mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-12 08:20:56 +00:00
chore(tests): add type hints to helper test functions
This commit is contained in:
parent
5d76144ef5
commit
c77e18d77c
@ -18,6 +18,7 @@ import hashlib
|
|||||||
import hmac
|
import hmac
|
||||||
import struct
|
import struct
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
import ecdsa
|
import ecdsa
|
||||||
from ecdsa.curves import SECP256k1
|
from ecdsa.curves import SECP256k1
|
||||||
@ -27,7 +28,7 @@ from ecdsa.util import number_to_string, string_to_number
|
|||||||
from trezorlib import messages, tools
|
from trezorlib import messages, tools
|
||||||
|
|
||||||
|
|
||||||
def point_to_pubkey(point):
|
def point_to_pubkey(point: Point) -> bytes:
|
||||||
order = SECP256k1.order
|
order = SECP256k1.order
|
||||||
x_str = number_to_string(point.x(), order)
|
x_str = number_to_string(point.x(), order)
|
||||||
y_str = number_to_string(point.y(), order)
|
y_str = number_to_string(point.y(), order)
|
||||||
@ -35,14 +36,14 @@ def point_to_pubkey(point):
|
|||||||
return struct.pack("B", (vk[63] & 1) + 2) + vk[0:32] # To compressed key
|
return struct.pack("B", (vk[63] & 1) + 2) + vk[0:32] # To compressed key
|
||||||
|
|
||||||
|
|
||||||
def sec_to_public_pair(pubkey):
|
def sec_to_public_pair(pubkey: bytes) -> Tuple[int, Any]:
|
||||||
"""Convert a public key in sec binary format to a public pair."""
|
"""Convert a public key in sec binary format to a public pair."""
|
||||||
x = string_to_number(pubkey[1:33])
|
x = string_to_number(pubkey[1:33])
|
||||||
sec0 = pubkey[:1]
|
sec0 = pubkey[:1]
|
||||||
if sec0 not in (b"\2", b"\3"):
|
if sec0 not in (b"\2", b"\3"):
|
||||||
raise ValueError("Compressed pubkey expected")
|
raise ValueError("Compressed pubkey expected")
|
||||||
|
|
||||||
def public_pair_for_x(generator, x, is_even):
|
def public_pair_for_x(generator, x: int, is_even: bool) -> Tuple[int, Any]:
|
||||||
curve = generator.curve()
|
curve = generator.curve()
|
||||||
p = curve.p()
|
p = curve.p()
|
||||||
alpha = (pow(x, 3, p) + curve.a() * x + curve.b()) % p
|
alpha = (pow(x, 3, p) + curve.a() * x + curve.b()) % p
|
||||||
@ -56,15 +57,15 @@ def sec_to_public_pair(pubkey):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def fingerprint(pubkey):
|
def fingerprint(pubkey: bytes) -> int:
|
||||||
return string_to_number(tools.hash_160(pubkey)[:4])
|
return string_to_number(tools.hash_160(pubkey)[:4])
|
||||||
|
|
||||||
|
|
||||||
def get_address(public_node, address_type):
|
def get_address(public_node: messages.HDNodeType, address_type: int) -> str:
|
||||||
return tools.public_key_to_bc_address(public_node.public_key, address_type)
|
return tools.public_key_to_bc_address(public_node.public_key, address_type)
|
||||||
|
|
||||||
|
|
||||||
def public_ckd(public_node, n):
|
def public_ckd(public_node: messages.HDNodeType, n: List[int]):
|
||||||
if not isinstance(n, list):
|
if not isinstance(n, list):
|
||||||
raise ValueError("Parameter must be a list")
|
raise ValueError("Parameter must be a list")
|
||||||
|
|
||||||
@ -76,7 +77,7 @@ def public_ckd(public_node, n):
|
|||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
def get_subnode(node, i):
|
def get_subnode(node: messages.HDNodeType, i: int) -> messages.HDNodeType:
|
||||||
# Public Child key derivation (CKD) algorithm of BIP32
|
# Public Child key derivation (CKD) algorithm of BIP32
|
||||||
i_as_bytes = struct.pack(">L", i)
|
i_as_bytes = struct.pack(">L", i)
|
||||||
|
|
||||||
@ -108,7 +109,7 @@ def get_subnode(node, i):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def serialize(node, version=0x0488B21E):
|
def serialize(node: messages.HDNodeType, version: int = 0x0488B21E) -> str:
|
||||||
s = b""
|
s = b""
|
||||||
s += struct.pack(">I", version)
|
s += struct.pack(">I", version)
|
||||||
s += struct.pack(">B", node.depth)
|
s += struct.pack(">B", node.depth)
|
||||||
@ -123,7 +124,7 @@ def serialize(node, version=0x0488B21E):
|
|||||||
return tools.b58encode(s)
|
return tools.b58encode(s)
|
||||||
|
|
||||||
|
|
||||||
def deserialize(xpub):
|
def deserialize(xpub: str) -> messages.HDNodeType:
|
||||||
data = tools.b58decode(xpub, None)
|
data = tools.b58decode(xpub, None)
|
||||||
|
|
||||||
if tools.btc_hash(data[:-4])[:4] != data[-4:]:
|
if tools.btc_hash(data[:-4])[:4] != data[-4:]:
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
|
from typing import Iterator, Tuple
|
||||||
|
|
||||||
DISPLAY_WIDTH = 240
|
DISPLAY_WIDTH = 240
|
||||||
DISPLAY_HEIGHT = 240
|
DISPLAY_HEIGHT = 240
|
||||||
|
|
||||||
|
|
||||||
def grid(dim, grid_cells, cell):
|
def grid(dim: int, grid_cells: int, cell: int) -> int:
|
||||||
step = dim // grid_cells
|
step = dim // grid_cells
|
||||||
ofs = step // 2
|
ofs = step // 2
|
||||||
return cell * step + ofs
|
return cell * step + ofs
|
||||||
@ -34,15 +36,15 @@ RESET_WORD_CHECK = [
|
|||||||
BUTTON_LETTERS = ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz")
|
BUTTON_LETTERS = ("ab", "cd", "ef", "ghij", "klm", "nopq", "rs", "tuv", "wxyz")
|
||||||
|
|
||||||
|
|
||||||
def grid35(x, y):
|
def grid35(x: int, y: int) -> Tuple[int, int]:
|
||||||
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 5, y)
|
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 5, y)
|
||||||
|
|
||||||
|
|
||||||
def grid34(x, y):
|
def grid34(x: int, y: int) -> Tuple[int, int]:
|
||||||
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 4, y)
|
return grid(DISPLAY_WIDTH, 3, x), grid(DISPLAY_HEIGHT, 4, y)
|
||||||
|
|
||||||
|
|
||||||
def type_word(word):
|
def type_word(word: str) -> Iterator[Tuple[int, int]]:
|
||||||
for l in word:
|
for l in word:
|
||||||
idx = next(i for i, letters in enumerate(BUTTON_LETTERS) if l in letters)
|
idx = next(i for i, letters in enumerate(BUTTON_LETTERS) if l in letters)
|
||||||
grid_x = idx % 3
|
grid_x = idx % 3
|
||||||
|
@ -18,12 +18,18 @@ import json
|
|||||||
import os
|
import os
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Generator, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from trezorlib import btc, tools
|
from trezorlib import btc, tools
|
||||||
from trezorlib.messages import ButtonRequestType as B
|
from trezorlib.messages import ButtonRequestType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezorlib.debuglink import DebugLink, TrezorClientDebugLink as Client
|
||||||
|
from trezorlib.messages import ButtonRequest
|
||||||
|
from _pytest.mark.structures import MarkDecorator
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
# 1 2 3 4 5 6 7 8 9 10 11 12
|
# 1 2 3 4 5 6 7 8 9 10 11 12
|
||||||
@ -56,7 +62,7 @@ COMMON_FIXTURES_DIR = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parametrize_using_common_fixtures(*paths):
|
def parametrize_using_common_fixtures(*paths: str) -> "MarkDecorator":
|
||||||
fixtures = []
|
fixtures = []
|
||||||
for path in paths:
|
for path in paths:
|
||||||
fixtures.append(json.loads((COMMON_FIXTURES_DIR / path).read_text()))
|
fixtures.append(json.loads((COMMON_FIXTURES_DIR / path).read_text()))
|
||||||
@ -85,7 +91,9 @@ def parametrize_using_common_fixtures(*paths):
|
|||||||
return pytest.mark.parametrize("parameters, result", tests)
|
return pytest.mark.parametrize("parameters, result", tests)
|
||||||
|
|
||||||
|
|
||||||
def generate_entropy(strength, internal_entropy, external_entropy):
|
def generate_entropy(
|
||||||
|
strength: int, internal_entropy: bytes, external_entropy: bytes
|
||||||
|
) -> bytes:
|
||||||
"""
|
"""
|
||||||
strength - length of produced seed. One of 128, 192, 256
|
strength - length of produced seed. One of 128, 192, 256
|
||||||
random - binary stream of random data from external HRNG
|
random - binary stream of random data from external HRNG
|
||||||
@ -116,7 +124,12 @@ def generate_entropy(strength, internal_entropy, external_entropy):
|
|||||||
return entropy_stripped
|
return entropy_stripped
|
||||||
|
|
||||||
|
|
||||||
def recovery_enter_shares(debug, shares, groups=False, click_info=False):
|
def recovery_enter_shares(
|
||||||
|
debug: "DebugLink",
|
||||||
|
shares: List[str],
|
||||||
|
groups: bool = False,
|
||||||
|
click_info: bool = False,
|
||||||
|
) -> Generator[None, "ButtonRequest", None]:
|
||||||
"""Perform the recovery flow for a set of Shamir shares.
|
"""Perform the recovery flow for a set of Shamir shares.
|
||||||
|
|
||||||
For use in an input flow function.
|
For use in an input flow function.
|
||||||
@ -134,7 +147,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
|
|||||||
debug.press_yes()
|
debug.press_yes()
|
||||||
# Input word number
|
# Input word number
|
||||||
br = yield
|
br = yield
|
||||||
assert br.code == B.MnemonicWordCount
|
assert br.code == ButtonRequestType.MnemonicWordCount
|
||||||
debug.input(str(word_count))
|
debug.input(str(word_count))
|
||||||
# Homescreen - proceed to share entry
|
# Homescreen - proceed to share entry
|
||||||
yield
|
yield
|
||||||
@ -142,7 +155,7 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
|
|||||||
# Enter shares
|
# Enter shares
|
||||||
for share in shares:
|
for share in shares:
|
||||||
br = yield
|
br = yield
|
||||||
assert br.code == B.MnemonicInput
|
assert br.code == ButtonRequestType.MnemonicInput
|
||||||
# Enter mnemonic words
|
# Enter mnemonic words
|
||||||
for word in share.split(" "):
|
for word in share.split(" "):
|
||||||
debug.input(word)
|
debug.input(word)
|
||||||
@ -167,7 +180,9 @@ def recovery_enter_shares(debug, shares, groups=False, click_info=False):
|
|||||||
debug.press_yes()
|
debug.press_yes()
|
||||||
|
|
||||||
|
|
||||||
def click_through(debug, screens, code=None):
|
def click_through(
|
||||||
|
debug: "DebugLink", screens: int, code: ButtonRequestType = None
|
||||||
|
) -> Generator[None, "ButtonRequest", None]:
|
||||||
"""Click through N dialog screens.
|
"""Click through N dialog screens.
|
||||||
|
|
||||||
For use in an input flow function.
|
For use in an input flow function.
|
||||||
@ -178,7 +193,7 @@ def click_through(debug, screens, code=None):
|
|||||||
# 2. Backup your seed
|
# 2. Backup your seed
|
||||||
# 3. Confirm warning
|
# 3. Confirm warning
|
||||||
# 4. Shares info
|
# 4. Shares info
|
||||||
yield from click_through(client.debug, screens=4, code=B.ResetDevice)
|
yield from click_through(client.debug, screens=4, code=ButtonRequestType.ResetDevice)
|
||||||
"""
|
"""
|
||||||
for _ in range(screens):
|
for _ in range(screens):
|
||||||
received = yield
|
received = yield
|
||||||
@ -187,7 +202,9 @@ def click_through(debug, screens, code=None):
|
|||||||
debug.press_yes()
|
debug.press_yes()
|
||||||
|
|
||||||
|
|
||||||
def read_and_confirm_mnemonic(debug, choose_wrong=False):
|
def read_and_confirm_mnemonic(
|
||||||
|
debug: "DebugLink", choose_wrong: bool = False
|
||||||
|
) -> Generator[None, "ButtonRequest", Optional[str]]:
|
||||||
"""Read a given number of mnemonic words from Trezor T screen and correctly
|
"""Read a given number of mnemonic words from Trezor T screen and correctly
|
||||||
answer confirmation questions. Return the full mnemonic.
|
answer confirmation questions. Return the full mnemonic.
|
||||||
|
|
||||||
@ -201,6 +218,7 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False):
|
|||||||
"""
|
"""
|
||||||
mnemonic = []
|
mnemonic = []
|
||||||
br = yield
|
br = yield
|
||||||
|
assert br.pages is not None
|
||||||
for _ in range(br.pages - 1):
|
for _ in range(br.pages - 1):
|
||||||
mnemonic.extend(debug.read_reset_word().split())
|
mnemonic.extend(debug.read_reset_word().split())
|
||||||
debug.swipe_up(wait=True)
|
debug.swipe_up(wait=True)
|
||||||
@ -221,13 +239,15 @@ def read_and_confirm_mnemonic(debug, choose_wrong=False):
|
|||||||
return " ".join(mnemonic)
|
return " ".join(mnemonic)
|
||||||
|
|
||||||
|
|
||||||
def get_test_address(client):
|
def get_test_address(client: "Client") -> str:
|
||||||
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
|
"""Fetch a testnet address on a fixed path. Useful to make a pin/passphrase
|
||||||
protected call, or to identify the root secret (seed+passphrase)"""
|
protected call, or to identify the root secret (seed+passphrase)"""
|
||||||
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
|
return btc.get_address(client, "Testnet", TEST_ADDRESS_N)
|
||||||
|
|
||||||
|
|
||||||
def assert_tx_matches(serialized_tx: bytes, hash_link: str, tx_hex: str = None) -> None:
|
def assert_tx_matches(
|
||||||
|
serialized_tx: bytes, hash_link: str, tx_hex: str = None
|
||||||
|
) -> None:
|
||||||
"""Verifies if a transaction is correctly formed."""
|
"""Verifies if a transaction is correctly formed."""
|
||||||
hash_str = hash_link.split("/")[-1]
|
hash_str = hash_link.split("/")[-1]
|
||||||
assert tools.tx_hash(serialized_tx).hex() == hash_str
|
assert tools.tx_hash(serialized_tx).hex() == hash_str
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -27,12 +28,17 @@ from . import ui_tests
|
|||||||
from .device_handler import BackgroundDeviceHandler
|
from .device_handler import BackgroundDeviceHandler
|
||||||
from .ui_tests.reporting import testreport
|
from .ui_tests.reporting import testreport
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from _pytest.config import Config
|
||||||
|
from _pytest.config.argparsing import Parser
|
||||||
|
from _pytest.terminal import TerminalReporter
|
||||||
|
|
||||||
# So that we see details of failed asserts from this module
|
# So that we see details of failed asserts from this module
|
||||||
pytest.register_assert_rewrite("tests.common")
|
pytest.register_assert_rewrite("tests.common")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def _raw_client(request):
|
def _raw_client(request: pytest.FixtureRequest) -> TrezorClientDebugLink:
|
||||||
path = os.environ.get("TREZOR_PATH")
|
path = os.environ.get("TREZOR_PATH")
|
||||||
interact = int(os.environ.get("INTERACT", 0))
|
interact = int(os.environ.get("INTERACT", 0))
|
||||||
if path:
|
if path:
|
||||||
@ -56,7 +62,9 @@ def _raw_client(request):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function")
|
@pytest.fixture(scope="function")
|
||||||
def client(request, _raw_client):
|
def client(
|
||||||
|
request: pytest.FixtureRequest, _raw_client: TrezorClientDebugLink
|
||||||
|
) -> TrezorClientDebugLink:
|
||||||
"""Client fixture.
|
"""Client fixture.
|
||||||
|
|
||||||
Every test function that requires a client instance will get it from here.
|
Every test function that requires a client instance will get it from here.
|
||||||
@ -156,20 +164,20 @@ def client(request, _raw_client):
|
|||||||
_raw_client.close()
|
_raw_client.close()
|
||||||
|
|
||||||
|
|
||||||
def pytest_sessionstart(session):
|
def pytest_sessionstart(session: pytest.Session) -> None:
|
||||||
ui_tests.read_fixtures()
|
ui_tests.read_fixtures()
|
||||||
if session.config.getoption("ui") == "test":
|
if session.config.getoption("ui") == "test":
|
||||||
testreport.clear_dir()
|
testreport.clear_dir()
|
||||||
|
|
||||||
|
|
||||||
def _should_write_ui_report(exitstatus):
|
def _should_write_ui_report(exitstatus: pytest.ExitCode) -> bool:
|
||||||
# generate UI report and check missing only if pytest is exitting cleanly
|
# generate UI report and check missing only if pytest is exitting cleanly
|
||||||
# I.e., the test suite passed or failed (as opposed to ctrl+c break, internal error,
|
# I.e., the test suite passed or failed (as opposed to ctrl+c break, internal error,
|
||||||
# etc.)
|
# etc.)
|
||||||
return exitstatus in (pytest.ExitCode.OK, pytest.ExitCode.TESTS_FAILED)
|
return exitstatus in (pytest.ExitCode.OK, pytest.ExitCode.TESTS_FAILED)
|
||||||
|
|
||||||
|
|
||||||
def pytest_sessionfinish(session, exitstatus):
|
def pytest_sessionfinish(session: pytest.Session, exitstatus: pytest.ExitCode) -> None:
|
||||||
if not _should_write_ui_report(exitstatus):
|
if not _should_write_ui_report(exitstatus):
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -183,7 +191,9 @@ def pytest_sessionfinish(session, exitstatus):
|
|||||||
ui_tests.write_fixtures(missing)
|
ui_tests.write_fixtures(missing)
|
||||||
|
|
||||||
|
|
||||||
def pytest_terminal_summary(terminalreporter, exitstatus, config):
|
def pytest_terminal_summary(
|
||||||
|
terminalreporter: "TerminalReporter", exitstatus: pytest.ExitCode, config: "Config"
|
||||||
|
) -> None:
|
||||||
println = terminalreporter.write_line
|
println = terminalreporter.write_line
|
||||||
println("")
|
println("")
|
||||||
|
|
||||||
@ -213,7 +223,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
|
|||||||
println("")
|
println("")
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser: "Parser") -> None:
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--ui",
|
"--ui",
|
||||||
action="store",
|
action="store",
|
||||||
@ -229,7 +239,7 @@ def pytest_addoption(parser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config: "Config") -> None:
|
||||||
"""Called at testsuite setup time.
|
"""Called at testsuite setup time.
|
||||||
|
|
||||||
Registers known markers, enables verbose output if requested.
|
Registers known markers, enables verbose output if requested.
|
||||||
@ -253,7 +263,7 @@ def pytest_configure(config):
|
|||||||
log.enable_debug_output()
|
log.enable_debug_output()
|
||||||
|
|
||||||
|
|
||||||
def pytest_runtest_setup(item):
|
def pytest_runtest_setup(item: pytest.Item) -> None:
|
||||||
"""Called for each test item (class, individual tests).
|
"""Called for each test item (class, individual tests).
|
||||||
|
|
||||||
Ensures that altcoin tests are skipped, and that no test is skipped on
|
Ensures that altcoin tests are skipped, and that no test is skipped on
|
||||||
@ -267,7 +277,7 @@ def pytest_runtest_setup(item):
|
|||||||
pytest.skip("Skipping altcoin test")
|
pytest.skip("Skipping altcoin test")
|
||||||
|
|
||||||
|
|
||||||
def pytest_runtest_teardown(item):
|
def pytest_runtest_teardown(item: pytest.Item) -> None:
|
||||||
"""Called after a test item finishes.
|
"""Called after a test item finishes.
|
||||||
|
|
||||||
Dumps the current UI test report HTML.
|
Dumps the current UI test report HTML.
|
||||||
@ -277,7 +287,7 @@ def pytest_runtest_teardown(item):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||||
def pytest_runtest_makereport(item, call):
|
def pytest_runtest_makereport(item: pytest.Item, call) -> None:
|
||||||
# Make test results available in fixtures.
|
# Make test results available in fixtures.
|
||||||
# See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
|
# See https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures
|
||||||
# The device_handler fixture uses this as 'request.node.rep_call.passed' attribute,
|
# The device_handler fixture uses this as 'request.node.rep_call.passed' attribute,
|
||||||
@ -288,7 +298,9 @@ def pytest_runtest_makereport(item, call):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def device_handler(client, request):
|
def device_handler(
|
||||||
|
client: TrezorClientDebugLink, request: pytest.FixtureRequest
|
||||||
|
) -> None:
|
||||||
device_handler = BackgroundDeviceHandler(client)
|
device_handler = BackgroundDeviceHandler(client)
|
||||||
yield device_handler
|
yield device_handler
|
||||||
|
|
||||||
@ -299,5 +311,5 @@ def device_handler(client, request):
|
|||||||
|
|
||||||
# if test finished, make sure all background tasks are done
|
# if test finished, make sure all background tasks are done
|
||||||
finalized_ok = device_handler.check_finalize()
|
finalized_ok = device_handler.check_finalize()
|
||||||
if request.node.rep_call.passed and not finalized_ok:
|
if request.node.rep_call.passed and not finalized_ok: # type: ignore [rep_call must exist]
|
||||||
raise RuntimeError("Test did not check result of background task")
|
raise RuntimeError("Test did not check result of background task")
|
||||||
|
@ -1,8 +1,15 @@
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezorlib.client import PASSPHRASE_ON_DEVICE
|
from trezorlib.client import PASSPHRASE_ON_DEVICE
|
||||||
from trezorlib.transport import udp
|
from trezorlib.transport import udp
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from trezorlib.messages import Features
|
||||||
|
from trezorlib.debuglink import TrezorClientDebugLink, DebugLink
|
||||||
|
from trezorlib._internal.emulator import Emulator
|
||||||
|
|
||||||
|
|
||||||
udp.SOCKET_TIMEOUT = 0.1
|
udp.SOCKET_TIMEOUT = 0.1
|
||||||
|
|
||||||
|
|
||||||
@ -26,13 +33,13 @@ class NullUI:
|
|||||||
class BackgroundDeviceHandler:
|
class BackgroundDeviceHandler:
|
||||||
_pool = ThreadPoolExecutor()
|
_pool = ThreadPoolExecutor()
|
||||||
|
|
||||||
def __init__(self, client):
|
def __init__(self, client: "TrezorClientDebugLink") -> None:
|
||||||
self._configure_client(client)
|
self._configure_client(client)
|
||||||
self.task = None
|
self.task = None
|
||||||
|
|
||||||
def _configure_client(self, client):
|
def _configure_client(self, client: "TrezorClientDebugLink") -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
self.client.ui = NullUI
|
self.client.ui = NullUI # type: ignore [NullUI is OK UI]
|
||||||
self.client.watch_layout(True)
|
self.client.watch_layout(True)
|
||||||
|
|
||||||
def run(self, function, *args, **kwargs):
|
def run(self, function, *args, **kwargs):
|
||||||
@ -40,7 +47,7 @@ class BackgroundDeviceHandler:
|
|||||||
raise RuntimeError("Wait for previous task first")
|
raise RuntimeError("Wait for previous task first")
|
||||||
self.task = self._pool.submit(function, self.client, *args, **kwargs)
|
self.task = self._pool.submit(function, self.client, *args, **kwargs)
|
||||||
|
|
||||||
def kill_task(self):
|
def kill_task(self) -> None:
|
||||||
if self.task is not None:
|
if self.task is not None:
|
||||||
# Force close the client, which should raise an exception in a client
|
# Force close the client, which should raise an exception in a client
|
||||||
# waiting on IO. Does not work over Bridge, because bridge doesn't have
|
# waiting on IO. Does not work over Bridge, because bridge doesn't have
|
||||||
@ -53,11 +60,11 @@ class BackgroundDeviceHandler:
|
|||||||
pass
|
pass
|
||||||
self.task = None
|
self.task = None
|
||||||
|
|
||||||
def restart(self, emulator):
|
def restart(self, emulator: "Emulator"):
|
||||||
# TODO handle actual restart as well
|
# TODO handle actual restart as well
|
||||||
self.kill_task()
|
self.kill_task()
|
||||||
emulator.restart()
|
emulator.restart()
|
||||||
self._configure_client(emulator.client)
|
self._configure_client(emulator.client) # type: ignore [client cannot be None]
|
||||||
|
|
||||||
def result(self):
|
def result(self):
|
||||||
if self.task is None:
|
if self.task is None:
|
||||||
@ -67,25 +74,25 @@ class BackgroundDeviceHandler:
|
|||||||
finally:
|
finally:
|
||||||
self.task = None
|
self.task = None
|
||||||
|
|
||||||
def features(self):
|
def features(self) -> "Features":
|
||||||
if self.task is not None:
|
if self.task is not None:
|
||||||
raise RuntimeError("Cannot query features while task is running")
|
raise RuntimeError("Cannot query features while task is running")
|
||||||
self.client.init_device()
|
self.client.init_device()
|
||||||
return self.client.features
|
return self.client.features
|
||||||
|
|
||||||
def debuglink(self):
|
def debuglink(self) -> "DebugLink":
|
||||||
return self.client.debug
|
return self.client.debug
|
||||||
|
|
||||||
def check_finalize(self):
|
def check_finalize(self) -> bool:
|
||||||
if self.task is not None:
|
if self.task is not None:
|
||||||
self.kill_task()
|
self.kill_task()
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> "BackgroundDeviceHandler":
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||||
finalized_ok = self.check_finalize()
|
finalized_ok = self.check_finalize()
|
||||||
if exc_type is None and not finalized_ok:
|
if exc_type is None and not finalized_ok:
|
||||||
raise RuntimeError("Exit while task is unfinished")
|
raise RuntimeError("Exit while task is unfinished")
|
||||||
|
@ -17,8 +17,9 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
from trezorlib._internal.emulator import CoreEmulator, LegacyEmulator
|
from trezorlib._internal.emulator import CoreEmulator, Emulator, LegacyEmulator
|
||||||
|
|
||||||
ROOT = Path(__file__).resolve().parent.parent
|
ROOT = Path(__file__).resolve().parent.parent
|
||||||
BINDIR = ROOT / "tests" / "emulators"
|
BINDIR = ROOT / "tests" / "emulators"
|
||||||
@ -33,18 +34,18 @@ CORE_SRC_DIR = ROOT / "core" / "src"
|
|||||||
ENV = {"SDL_VIDEODRIVER": "dummy"}
|
ENV = {"SDL_VIDEODRIVER": "dummy"}
|
||||||
|
|
||||||
|
|
||||||
def check_version(tag, version_tuple):
|
def check_version(tag: str, version_tuple: Tuple[int, int, int]) -> None:
|
||||||
if tag is not None and tag.startswith("v") and len(tag.split(".")) == 3:
|
if tag is not None and tag.startswith("v") and len(tag.split(".")) == 3:
|
||||||
version = ".".join(str(i) for i in version_tuple)
|
version = ".".join(str(i) for i in version_tuple)
|
||||||
if tag[1:] != version:
|
if tag[1:] != version:
|
||||||
raise RuntimeError(f"Version mismatch: tag {tag} reports version {version}")
|
raise RuntimeError(f"Version mismatch: tag {tag} reports version {version}")
|
||||||
|
|
||||||
|
|
||||||
def filename_from_tag(gen, tag):
|
def filename_from_tag(gen: str, tag: str) -> Path:
|
||||||
return BINDIR / f"trezor-emu-{gen}-{tag}"
|
return BINDIR / f"trezor-emu-{gen}-{tag}"
|
||||||
|
|
||||||
|
|
||||||
def get_tags():
|
def get_tags() -> Dict[str, List[str]]:
|
||||||
files = list(BINDIR.iterdir())
|
files = list(BINDIR.iterdir())
|
||||||
if not files:
|
if not files:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -66,7 +67,7 @@ ALL_TAGS = get_tags()
|
|||||||
|
|
||||||
|
|
||||||
class EmulatorWrapper:
|
class EmulatorWrapper:
|
||||||
def __init__(self, gen, tag=None, storage=None):
|
def __init__(self, gen: str, tag: str = None, storage: bytes = None) -> None:
|
||||||
if tag is not None:
|
if tag is not None:
|
||||||
executable = filename_from_tag(gen, tag)
|
executable = filename_from_tag(gen, tag)
|
||||||
else:
|
else:
|
||||||
@ -96,11 +97,15 @@ class EmulatorWrapper:
|
|||||||
workdir=workdir,
|
workdir=workdir,
|
||||||
headless=True,
|
headless=True,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unrecognized gen - {gen} - only 'core' and 'legacy' supported"
|
||||||
|
)
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self) -> Emulator:
|
||||||
self.emulator.start()
|
self.emulator.start()
|
||||||
return self.emulator
|
return self.emulator
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||||
self.emulator.stop()
|
self.emulator.stop()
|
||||||
self.profile_dir.cleanup()
|
self.profile_dir.cleanup()
|
||||||
|
@ -5,9 +5,9 @@ import multiprocessing
|
|||||||
import os
|
import os
|
||||||
import posixpath
|
import posixpath
|
||||||
import time
|
import time
|
||||||
import urllib
|
|
||||||
import webbrowser
|
import webbrowser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from urllib.parse import unquote
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
@ -16,16 +16,16 @@ TEST_RESULT_PATH = ROOT / "tests" / "ui_tests" / "reporting" / "reports" / "test
|
|||||||
|
|
||||||
|
|
||||||
class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
|
class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
def end_headers(self):
|
def end_headers(self) -> None:
|
||||||
self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
|
self.send_header("Cache-Control", "no-cache, no-store, must-revalidate")
|
||||||
self.send_header("Pragma", "no-cache")
|
self.send_header("Pragma", "no-cache")
|
||||||
self.send_header("Expires", "0")
|
self.send_header("Expires", "0")
|
||||||
return super().end_headers()
|
return super().end_headers()
|
||||||
|
|
||||||
def log_message(self, format, *args):
|
def log_message(self, format, *args) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def translate_path(self, path):
|
def translate_path(self, path: str) -> str:
|
||||||
# XXX
|
# XXX
|
||||||
# Copy-pasted from Python 3.8 BaseHTTPRequestHandler so that we can inject
|
# Copy-pasted from Python 3.8 BaseHTTPRequestHandler so that we can inject
|
||||||
# the `directory` parameter.
|
# the `directory` parameter.
|
||||||
@ -36,9 +36,9 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
# Don't forget explicit trailing slash when normalizing. Issue17324
|
# Don't forget explicit trailing slash when normalizing. Issue17324
|
||||||
trailing_slash = path.rstrip().endswith("/")
|
trailing_slash = path.rstrip().endswith("/")
|
||||||
try:
|
try:
|
||||||
path = urllib.parse.unquote(path, errors="surrogatepass")
|
path = unquote(path, errors="surrogatepass")
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
path = urllib.parse.unquote(path)
|
path = unquote(path)
|
||||||
path = posixpath.normpath(path)
|
path = posixpath.normpath(path)
|
||||||
words = path.split("/")
|
words = path.split("/")
|
||||||
words = filter(None, words)
|
words = filter(None, words)
|
||||||
@ -53,13 +53,13 @@ class NoCacheRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def launch_http_server(port):
|
def launch_http_server(port: int) -> None:
|
||||||
http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port)
|
http.server.test(HandlerClass=NoCacheRequestHandler, bind="localhost", port=port) # type: ignore [test is defined]
|
||||||
|
|
||||||
|
|
||||||
@click.command()
|
@click.command()
|
||||||
@click.option("-p", "--port", type=int, default=8000)
|
@click.option("-p", "--port", type=int, default=8000)
|
||||||
def main(port):
|
def main(port: int):
|
||||||
httpd = multiprocessing.Process(target=launch_http_server, args=(port,))
|
httpd = multiprocessing.Process(target=launch_http_server, args=(port,))
|
||||||
httpd.start()
|
httpd.start()
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
@ -122,7 +122,7 @@ def cli(tx, coin_name):
|
|||||||
tx_dict = protobuf.to_dict(tx_proto)
|
tx_dict = protobuf.to_dict(tx_proto)
|
||||||
tx_json = json.dumps(tx_dict, sort_keys=True, indent=2) + "\n"
|
tx_json = json.dumps(tx_dict, sort_keys=True, indent=2) + "\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise click.ClickException(e) from e
|
raise click.ClickException(str(e)) from e
|
||||||
|
|
||||||
cache_dir = CACHE_PATH / coin_name
|
cache_dir = CACHE_PATH / coin_name
|
||||||
if not cache_dir.exists():
|
if not cache_dir.exists():
|
||||||
|
@ -4,23 +4,26 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, Generator, Set
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest.outcomes import Failed
|
from _pytest.outcomes import Failed
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
|
|
||||||
from .reporting import testreport
|
from .reporting import testreport
|
||||||
|
|
||||||
UI_TESTS_DIR = Path(__file__).resolve().parent
|
UI_TESTS_DIR = Path(__file__).resolve().parent
|
||||||
SCREENS_DIR = UI_TESTS_DIR / "screens"
|
SCREENS_DIR = UI_TESTS_DIR / "screens"
|
||||||
HASH_FILE = UI_TESTS_DIR / "fixtures.json"
|
HASH_FILE = UI_TESTS_DIR / "fixtures.json"
|
||||||
SUGGESTION_FILE = UI_TESTS_DIR / "fixtures.suggestion.json"
|
SUGGESTION_FILE = UI_TESTS_DIR / "fixtures.suggestion.json"
|
||||||
FILE_HASHES = {}
|
FILE_HASHES: Dict[str, str] = {}
|
||||||
ACTUAL_HASHES = {}
|
ACTUAL_HASHES: Dict[str, str] = {}
|
||||||
PROCESSED = set()
|
PROCESSED: Set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
def get_test_name(node_id):
|
def get_test_name(node_id: str) -> str:
|
||||||
# Test item name is usually function name, but when parametrization is used,
|
# Test item name is usually function name, but when parametrization is used,
|
||||||
# parameters are also part of the name. Some functions have very long parameter
|
# parameters are also part of the name. Some functions have very long parameter
|
||||||
# names (tx hashes etc) that run out of maximum allowable filename length, so
|
# names (tx hashes etc) that run out of maximum allowable filename length, so
|
||||||
@ -34,14 +37,14 @@ def get_test_name(node_id):
|
|||||||
return new_name[:91] + "-" + hashlib.sha256(new_name.encode()).hexdigest()[:8]
|
return new_name[:91] + "-" + hashlib.sha256(new_name.encode()).hexdigest()[:8]
|
||||||
|
|
||||||
|
|
||||||
def _process_recorded(screen_path, test_name):
|
def _process_recorded(screen_path: Path, test_name: str) -> None:
|
||||||
# calculate hash
|
# calculate hash
|
||||||
FILE_HASHES[test_name] = _hash_files(screen_path)
|
FILE_HASHES[test_name] = _hash_files(screen_path)
|
||||||
_rename_records(screen_path)
|
_rename_records(screen_path)
|
||||||
PROCESSED.add(test_name)
|
PROCESSED.add(test_name)
|
||||||
|
|
||||||
|
|
||||||
def _rename_records(screen_path):
|
def _rename_records(screen_path: Path) -> None:
|
||||||
# rename screenshots
|
# rename screenshots
|
||||||
for index, record in enumerate(sorted(screen_path.iterdir())):
|
for index, record in enumerate(sorted(screen_path.iterdir())):
|
||||||
record.replace(screen_path / f"{index:08}.png")
|
record.replace(screen_path / f"{index:08}.png")
|
||||||
@ -65,7 +68,7 @@ def _get_bytes_from_png(png_file: str) -> bytes:
|
|||||||
return Image.open(png_file).tobytes()
|
return Image.open(png_file).tobytes()
|
||||||
|
|
||||||
|
|
||||||
def _process_tested(fixture_test_path, test_name):
|
def _process_tested(fixture_test_path: Path, test_name: str) -> None:
|
||||||
PROCESSED.add(test_name)
|
PROCESSED.add(test_name)
|
||||||
|
|
||||||
actual_path = fixture_test_path / "actual"
|
actual_path = fixture_test_path / "actual"
|
||||||
@ -79,6 +82,7 @@ def _process_tested(fixture_test_path, test_name):
|
|||||||
pytest.fail(f"Hash of {test_name} not found in fixtures.json")
|
pytest.fail(f"Hash of {test_name} not found in fixtures.json")
|
||||||
|
|
||||||
if actual_hash != expected_hash:
|
if actual_hash != expected_hash:
|
||||||
|
assert expected_hash is not None
|
||||||
file_path = testreport.failed(
|
file_path = testreport.failed(
|
||||||
fixture_test_path, test_name, actual_hash, expected_hash
|
fixture_test_path, test_name, actual_hash, expected_hash
|
||||||
)
|
)
|
||||||
@ -94,7 +98,9 @@ def _process_tested(fixture_test_path, test_name):
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def screen_recording(client, request):
|
def screen_recording(
|
||||||
|
client: Client, request: pytest.FixtureRequest
|
||||||
|
) -> Generator[None, None, None]:
|
||||||
test_ui = request.config.getoption("ui")
|
test_ui = request.config.getoption("ui")
|
||||||
test_name = get_test_name(request.node.nodeid)
|
test_name = get_test_name(request.node.nodeid)
|
||||||
screens_test_path = SCREENS_DIR / test_name
|
screens_test_path = SCREENS_DIR / test_name
|
||||||
@ -126,26 +132,26 @@ def screen_recording(client, request):
|
|||||||
_process_tested(screens_test_path, test_name)
|
_process_tested(screens_test_path, test_name)
|
||||||
|
|
||||||
|
|
||||||
def list_missing():
|
def list_missing() -> Set[str]:
|
||||||
return set(FILE_HASHES.keys()) - PROCESSED
|
return set(FILE_HASHES.keys()) - PROCESSED
|
||||||
|
|
||||||
|
|
||||||
def read_fixtures():
|
def read_fixtures() -> None:
|
||||||
if not HASH_FILE.exists():
|
if not HASH_FILE.exists():
|
||||||
raise ValueError("File fixtures.json not found.")
|
raise ValueError("File fixtures.json not found.")
|
||||||
global FILE_HASHES
|
global FILE_HASHES
|
||||||
FILE_HASHES = json.loads(HASH_FILE.read_text())
|
FILE_HASHES = json.loads(HASH_FILE.read_text())
|
||||||
|
|
||||||
|
|
||||||
def write_fixtures(remove_missing: bool):
|
def write_fixtures(remove_missing: bool) -> None:
|
||||||
HASH_FILE.write_text(_get_fixtures_content(FILE_HASHES, remove_missing))
|
HASH_FILE.write_text(_get_fixtures_content(FILE_HASHES, remove_missing))
|
||||||
|
|
||||||
|
|
||||||
def write_fixtures_suggestion(remove_missing: bool):
|
def write_fixtures_suggestion(remove_missing: bool) -> None:
|
||||||
SUGGESTION_FILE.write_text(_get_fixtures_content(ACTUAL_HASHES, remove_missing))
|
SUGGESTION_FILE.write_text(_get_fixtures_content(ACTUAL_HASHES, remove_missing))
|
||||||
|
|
||||||
|
|
||||||
def _get_fixtures_content(fixtures: dict, remove_missing: bool):
|
def _get_fixtures_content(fixtures: Dict[str, str], remove_missing: bool) -> str:
|
||||||
if remove_missing:
|
if remove_missing:
|
||||||
fixtures = {i: fixtures[i] for i in PROCESSED}
|
fixtures = {i: fixtures[i] for i in PROCESSED}
|
||||||
else:
|
else:
|
||||||
@ -154,7 +160,7 @@ def _get_fixtures_content(fixtures: dict, remove_missing: bool):
|
|||||||
return json.dumps(fixtures, indent="", sort_keys=True) + "\n"
|
return json.dumps(fixtures, indent="", sort_keys=True) + "\n"
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
read_fixtures()
|
read_fixtures()
|
||||||
for record in SCREENS_DIR.iterdir():
|
for record in SCREENS_DIR.iterdir():
|
||||||
if not (record / "actual").exists():
|
if not (record / "actual").exists():
|
||||||
|
@ -14,7 +14,7 @@ FIXTURES_CURRENT = Path(__file__).resolve().parent.parent / "fixtures.json"
|
|||||||
_dns_failed = False
|
_dns_failed = False
|
||||||
|
|
||||||
|
|
||||||
def fetch_recorded(hash, path):
|
def fetch_recorded(hash: str, path: Path) -> None:
|
||||||
global _dns_failed
|
global _dns_failed
|
||||||
|
|
||||||
if _dns_failed:
|
if _dns_failed:
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
import base64
|
import base64
|
||||||
import filecmp
|
import filecmp
|
||||||
from itertools import zip_longest
|
from itertools import zip_longest
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
from dominate.tags import a, i, img, table, td, th, tr
|
from dominate.tags import a, i, img, table, td, th, tr
|
||||||
|
|
||||||
|
|
||||||
def report_links(tests, reports_path, actual_hashes=None):
|
def report_links(
|
||||||
|
tests: List[Path], reports_path: Path, actual_hashes: Dict[str, str] = None
|
||||||
|
) -> None:
|
||||||
if actual_hashes is None:
|
if actual_hashes is None:
|
||||||
actual_hashes = {}
|
actual_hashes = {}
|
||||||
|
|
||||||
@ -21,12 +25,12 @@ def report_links(tests, reports_path, actual_hashes=None):
|
|||||||
td(a(test.name, href=path))
|
td(a(test.name, href=path))
|
||||||
|
|
||||||
|
|
||||||
def write(fixture_test_path, doc, filename):
|
def write(fixture_test_path: Path, doc, filename: str) -> Path:
|
||||||
(fixture_test_path / filename).write_text(doc.render())
|
(fixture_test_path / filename).write_text(doc.render())
|
||||||
return fixture_test_path / filename
|
return fixture_test_path / filename
|
||||||
|
|
||||||
|
|
||||||
def image(src):
|
def image(src: Path) -> None:
|
||||||
with td():
|
with td():
|
||||||
if src:
|
if src:
|
||||||
# open image file
|
# open image file
|
||||||
@ -41,7 +45,7 @@ def image(src):
|
|||||||
i("missing")
|
i("missing")
|
||||||
|
|
||||||
|
|
||||||
def diff_table(left_screens, right_screens):
|
def diff_table(left_screens: List[Path], right_screens: List[Path]) -> None:
|
||||||
for left, right in zip_longest(left_screens, right_screens):
|
for left, right in zip_longest(left_screens, right_screens):
|
||||||
if left and right and filecmp.cmp(right, left):
|
if left and right and filecmp.cmp(right, left):
|
||||||
background = "white"
|
background = "white"
|
||||||
|
@ -2,6 +2,7 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
import dominate
|
import dominate
|
||||||
from dominate.tags import br, h1, h2, hr, i, p, table, td, th, tr
|
from dominate.tags import br, h1, h2, hr, i, p, table, td, th, tr
|
||||||
@ -14,7 +15,7 @@ REPORTS_PATH = Path(__file__).resolve().parent / "reports" / "master_diff"
|
|||||||
RECORDED_SCREENS_PATH = Path(__file__).resolve().parent.parent / "screens"
|
RECORDED_SCREENS_PATH = Path(__file__).resolve().parent.parent / "screens"
|
||||||
|
|
||||||
|
|
||||||
def get_diff():
|
def get_diff() -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
|
||||||
master = download.fetch_fixtures_master()
|
master = download.fetch_fixtures_master()
|
||||||
current = download.fetch_fixtures_current()
|
current = download.fetch_fixtures_current()
|
||||||
|
|
||||||
@ -34,7 +35,7 @@ def get_diff():
|
|||||||
return removed, added, diff
|
return removed, added, diff
|
||||||
|
|
||||||
|
|
||||||
def removed(screens_path, test_name):
|
def removed(screens_path: Path, test_name: str) -> Path:
|
||||||
doc = dominate.document(title=test_name)
|
doc = dominate.document(title=test_name)
|
||||||
screens = sorted(screens_path.iterdir())
|
screens = sorted(screens_path.iterdir())
|
||||||
|
|
||||||
@ -57,7 +58,7 @@ def removed(screens_path, test_name):
|
|||||||
return html.write(REPORTS_PATH / "removed", doc, test_name + ".html")
|
return html.write(REPORTS_PATH / "removed", doc, test_name + ".html")
|
||||||
|
|
||||||
|
|
||||||
def added(screens_path, test_name):
|
def added(screens_path: Path, test_name: str) -> Path:
|
||||||
doc = dominate.document(title=test_name)
|
doc = dominate.document(title=test_name)
|
||||||
screens = sorted(screens_path.iterdir())
|
screens = sorted(screens_path.iterdir())
|
||||||
|
|
||||||
@ -81,8 +82,12 @@ def added(screens_path, test_name):
|
|||||||
|
|
||||||
|
|
||||||
def diff(
|
def diff(
|
||||||
master_screens_path, current_screens_path, test_name, master_hash, current_hash
|
master_screens_path: Path,
|
||||||
):
|
current_screens_path: Path,
|
||||||
|
test_name: str,
|
||||||
|
master_hash: str,
|
||||||
|
current_hash: str,
|
||||||
|
) -> Path:
|
||||||
doc = dominate.document(title=test_name)
|
doc = dominate.document(title=test_name)
|
||||||
master_screens = sorted(master_screens_path.iterdir())
|
master_screens = sorted(master_screens_path.iterdir())
|
||||||
current_screens = sorted(current_screens_path.iterdir())
|
current_screens = sorted(current_screens_path.iterdir())
|
||||||
@ -109,7 +114,7 @@ def diff(
|
|||||||
return html.write(REPORTS_PATH / "diff", doc, test_name + ".html")
|
return html.write(REPORTS_PATH / "diff", doc, test_name + ".html")
|
||||||
|
|
||||||
|
|
||||||
def index():
|
def index() -> Path:
|
||||||
removed = list((REPORTS_PATH / "removed").iterdir())
|
removed = list((REPORTS_PATH / "removed").iterdir())
|
||||||
added = list((REPORTS_PATH / "added").iterdir())
|
added = list((REPORTS_PATH / "added").iterdir())
|
||||||
diff = list((REPORTS_PATH / "diff").iterdir())
|
diff = list((REPORTS_PATH / "diff").iterdir())
|
||||||
@ -140,7 +145,7 @@ def index():
|
|||||||
return html.write(REPORTS_PATH, doc, "index.html")
|
return html.write(REPORTS_PATH, doc, "index.html")
|
||||||
|
|
||||||
|
|
||||||
def create_dirs():
|
def create_dirs() -> None:
|
||||||
# delete the reports dir to clear previous entries and create folders
|
# delete the reports dir to clear previous entries and create folders
|
||||||
shutil.rmtree(REPORTS_PATH, ignore_errors=True)
|
shutil.rmtree(REPORTS_PATH, ignore_errors=True)
|
||||||
REPORTS_PATH.mkdir()
|
REPORTS_PATH.mkdir()
|
||||||
@ -149,7 +154,7 @@ def create_dirs():
|
|||||||
(REPORTS_PATH / "diff").mkdir()
|
(REPORTS_PATH / "diff").mkdir()
|
||||||
|
|
||||||
|
|
||||||
def create_reports():
|
def create_reports() -> None:
|
||||||
removed_tests, added_tests, diff_tests = get_diff()
|
removed_tests, added_tests, diff_tests = get_diff()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -2,6 +2,7 @@ import shutil
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from distutils.dir_util import copy_tree
|
from distutils.dir_util import copy_tree
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import dominate
|
import dominate
|
||||||
import dominate.tags as t
|
import dominate.tags as t
|
||||||
@ -16,10 +17,12 @@ REPORTS_PATH = HERE / "reports" / "test"
|
|||||||
STYLE = (HERE / "testreport.css").read_text()
|
STYLE = (HERE / "testreport.css").read_text()
|
||||||
SCRIPT = (HERE / "testreport.js").read_text()
|
SCRIPT = (HERE / "testreport.js").read_text()
|
||||||
|
|
||||||
ACTUAL_HASHES = {}
|
ACTUAL_HASHES: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
def document(title, actual_hash=None, index=False):
|
def document(
|
||||||
|
title: str, actual_hash: str = None, index: bool = False
|
||||||
|
) -> dominate.document:
|
||||||
doc = dominate.document(title=title)
|
doc = dominate.document(title=title)
|
||||||
style = t.style()
|
style = t.style()
|
||||||
style.add_raw_string(STYLE)
|
style.add_raw_string(STYLE)
|
||||||
@ -36,7 +39,7 @@ def document(title, actual_hash=None, index=False):
|
|||||||
return doc
|
return doc
|
||||||
|
|
||||||
|
|
||||||
def _header(test_name, expected_hash, actual_hash):
|
def _header(test_name: str, expected_hash: str, actual_hash: str) -> None:
|
||||||
h1(test_name)
|
h1(test_name)
|
||||||
with div():
|
with div():
|
||||||
if actual_hash == expected_hash:
|
if actual_hash == expected_hash:
|
||||||
@ -54,7 +57,7 @@ def _header(test_name, expected_hash, actual_hash):
|
|||||||
hr()
|
hr()
|
||||||
|
|
||||||
|
|
||||||
def clear_dir():
|
def clear_dir() -> None:
|
||||||
# delete and create the reports dir to clear previous entries
|
# delete and create the reports dir to clear previous entries
|
||||||
shutil.rmtree(REPORTS_PATH, ignore_errors=True)
|
shutil.rmtree(REPORTS_PATH, ignore_errors=True)
|
||||||
REPORTS_PATH.mkdir()
|
REPORTS_PATH.mkdir()
|
||||||
@ -62,7 +65,7 @@ def clear_dir():
|
|||||||
(REPORTS_PATH / "passed").mkdir()
|
(REPORTS_PATH / "passed").mkdir()
|
||||||
|
|
||||||
|
|
||||||
def index():
|
def index() -> Path:
|
||||||
passed_tests = list((REPORTS_PATH / "passed").iterdir())
|
passed_tests = list((REPORTS_PATH / "passed").iterdir())
|
||||||
failed_tests = list((REPORTS_PATH / "failed").iterdir())
|
failed_tests = list((REPORTS_PATH / "failed").iterdir())
|
||||||
|
|
||||||
@ -104,7 +107,9 @@ def index():
|
|||||||
return html.write(REPORTS_PATH, doc, "index.html")
|
return html.write(REPORTS_PATH, doc, "index.html")
|
||||||
|
|
||||||
|
|
||||||
def failed(fixture_test_path, test_name, actual_hash, expected_hash):
|
def failed(
|
||||||
|
fixture_test_path: Path, test_name: str, actual_hash: str, expected_hash: str
|
||||||
|
) -> Path:
|
||||||
ACTUAL_HASHES[test_name] = actual_hash
|
ACTUAL_HASHES[test_name] = actual_hash
|
||||||
|
|
||||||
doc = document(title=test_name, actual_hash=actual_hash)
|
doc = document(title=test_name, actual_hash=actual_hash)
|
||||||
@ -147,7 +152,7 @@ def failed(fixture_test_path, test_name, actual_hash, expected_hash):
|
|||||||
return html.write(REPORTS_PATH / "failed", doc, test_name + ".html")
|
return html.write(REPORTS_PATH / "failed", doc, test_name + ".html")
|
||||||
|
|
||||||
|
|
||||||
def passed(fixture_test_path, test_name, actual_hash):
|
def passed(fixture_test_path: Path, test_name: str, actual_hash: str) -> Path:
|
||||||
copy_tree(str(fixture_test_path / "actual"), str(fixture_test_path / "recorded"))
|
copy_tree(str(fixture_test_path / "actual"), str(fixture_test_path / "recorded"))
|
||||||
|
|
||||||
doc = document(title=test_name)
|
doc = document(title=test_name)
|
||||||
|
@ -15,8 +15,10 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from _pytest.mark.structures import MarkDecorator
|
||||||
|
|
||||||
from ..emulators import ALL_TAGS, LOCAL_BUILD_PATHS
|
from ..emulators import ALL_TAGS, LOCAL_BUILD_PATHS
|
||||||
|
|
||||||
@ -44,7 +46,11 @@ core_only = pytest.mark.skipif(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0, 0)):
|
def for_all(
|
||||||
|
*args,
|
||||||
|
legacy_minimum_version: Tuple[int, int, int] = (1, 0, 0),
|
||||||
|
core_minimum_version: Tuple[int, int, int] = (2, 0, 0)
|
||||||
|
) -> "MarkDecorator":
|
||||||
"""Parametrizing decorator for test cases.
|
"""Parametrizing decorator for test cases.
|
||||||
|
|
||||||
Usage example:
|
Usage example:
|
||||||
@ -98,7 +104,7 @@ def for_all(*args, legacy_minimum_version=(1, 0, 0), core_minimum_version=(2, 0,
|
|||||||
return pytest.mark.parametrize("gen, tag", all_params)
|
return pytest.mark.parametrize("gen, tag", all_params)
|
||||||
|
|
||||||
|
|
||||||
def for_tags(*args):
|
def for_tags(*args: Tuple[str, List[str]]) -> "MarkDecorator":
|
||||||
enabled_gens = SELECTED_GENS or ("core", "legacy")
|
enabled_gens = SELECTED_GENS or ("core", "legacy")
|
||||||
return pytest.mark.parametrize(
|
return pytest.mark.parametrize(
|
||||||
"gen, tags", [(gen, tags) for gen, tags in args if gen in enabled_gens]
|
"gen, tags", [(gen, tags) for gen, tags in args if gen in enabled_gens]
|
||||||
|
Loading…
Reference in New Issue
Block a user