1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-20 17:19:01 +00:00

chore(tests): improve stability of tests

This commit is contained in:
M1nd3r 2025-04-12 00:54:08 +02:00
parent 34b87b7692
commit 5d755c82ff
5 changed files with 47 additions and 13 deletions

View File

@ -40,6 +40,7 @@ from .tools import parse_path
from .transport import Timeout
from .transport.session import ProtocolV2Channel, Session
from .transport.thp.protocol_v1 import ProtocolV1Channel
from trezorlib import exceptions
if t.TYPE_CHECKING:
from typing_extensions import Protocol
@ -1060,6 +1061,7 @@ class TrezorClientDebugLink(TrezorClient):
# by the device.
protocol: ProtocolV1Channel | ProtocolV2Channel
actual_responses: list[protobuf.MessageType] | None = None
def __init__(
self,
@ -1085,8 +1087,6 @@ class TrezorClientDebugLink(TrezorClient):
transport.open()
# set transport explicitly so that sync_responses can work
super().__init__(transport)
self.transport = transport
self.ui: DebugUI = DebugUI(self.debug)
@ -1100,6 +1100,13 @@ class TrezorClientDebugLink(TrezorClient):
self.pin_callback = get_pin
self.button_callback = self.ui.button_request
try:
super().__init__(transport)
except exceptions.DeviceLockedException:
self.use_pin_sequence(["1234"])
self.debug.input(self.debug.encode_pin("1234"))
super().__init__(transport)
self.sync_responses()
# So that we can choose right screenshotting logic (T1 vs TT)
@ -1332,6 +1339,7 @@ class TrezorClientDebugLink(TrezorClient):
# Propagate the exception through the input flow, so that we see in
# traceback where it is stuck.
input_flow.throw(exc_type, value, traceback)
self.actual_responses = None
@classmethod
def _verify_responses(
@ -1435,6 +1443,14 @@ class TrezorClientDebugLink(TrezorClient):
self.actual_responses.append(resp)
return resp
def notify_read(self, msg: protobuf.MessageType) -> None:
pass
try:
if self.actual_responses is not None:
self.actual_responses.append(msg)
except Exception as e:
print(e)
def load_device(
session: "Session",

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import typing as t
from .. import exceptions, messages, models
from ..client import MAX_PIN_LENGTH
from ..protobuf import MessageType
@ -274,8 +275,12 @@ class SessionV2(Session):
self.channel.write(self.sid, msg)
def _read(self) -> t.Any:
from trezorlib.debuglink import TrezorClientDebugLink
msg = self.channel.read(self.sid)
LOG.debug("reading message %s", type(msg))
if isinstance(self.client, TrezorClientDebugLink):
self.client.notify_read(msg)
return msg
def update_id_and_sid(self, id: bytes) -> None:

View File

@ -346,7 +346,11 @@ def _client_unlocked(
LOG = logging.getLogger(__name__)
LOG.error(f"Failed to re-create a client: {e}")
sleep(LOCK_TIME)
_raw_client = _get_raw_client(request)
try:
_raw_client = _raw_client.get_new_client()
except Exception as e:
sleep(1.5)
_raw_client = _get_raw_client(request)
session = _raw_client.get_seedless_session()
wipe_device(session)

View File

@ -58,6 +58,7 @@ def test_cancel_message_via_cancel(session: Session, message):
),
],
)
@pytest.mark.protocol("protocol_v1")
def test_cancel_message_via_initialize(session: Session, message):
resp = session.call_raw(message)
assert isinstance(resp, m.ButtonRequest)

View File

@ -19,6 +19,7 @@ from pathlib import Path
import pytest
from trezorlib import btc, device, exceptions, messages, misc, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.tools import parse_path
@ -211,7 +212,8 @@ def test_apply_homescreen_toif(session: Session):
def test_apply_homescreen_jpeg(session: Session):
with open(HERE / "test_bg.jpg", "rb") as f:
img = f.read()
# raise Exception("FAILS FOR SOME REASON ")
if session.protocol_version is ProtocolVersion.V2:
raise Exception("Message too large for THP")
with session.client as client:
_set_expected_responses(client)
device.apply_settings(session, homescreen=img)
@ -339,7 +341,9 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client:
with pytest.raises(
exceptions.TrezorFailure, match="Forbidden key path"
), session.client as client:
client.set_expected_responses([messages.Failure])
get_bad_address()
@ -352,11 +356,11 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways
with client:
with session.client as client:
client.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address]
)
IF = InputFlowConfirmAllWarnings(session.client)
IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get())
get_bad_address()
@ -366,11 +370,13 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client:
with pytest.raises(
exceptions.TrezorFailure, match="Forbidden key path"
), session.client as client:
client.set_expected_responses([messages.Failure])
get_bad_address()
with client:
with session.client as client:
client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings(
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
@ -378,11 +384,11 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily
with client:
with session.client as client:
client.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address]
)
if session.model is not models.T1B1:
if client.model is not models.T1B1:
IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get())
get_bad_address()
@ -404,7 +410,9 @@ def test_experimental_features(session: Session):
assert not session.features.experimental_features
with pytest.raises(exceptions.TrezorFailure, match="DataError"), client:
with pytest.raises(
exceptions.TrezorFailure, match="DataError"
), session.client as client:
client.set_expected_responses([messages.Failure])
experimental_call()
@ -414,7 +422,7 @@ def test_experimental_features(session: Session):
assert session.features.experimental_features
with client:
with session.client as client:
client.set_expected_responses([messages.Nonce])
experimental_call()