1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-04 07:59:06 +00:00

chore(tests): improve stability of tests

This commit is contained in:
M1nd3r 2025-04-12 00:54:08 +02:00
parent 46269ef935
commit cc7abfae8f
5 changed files with 47 additions and 13 deletions

View File

@ -40,6 +40,7 @@ from .tools import parse_path
from .transport import Timeout from .transport import Timeout
from .transport.session import ProtocolV2Channel, Session from .transport.session import ProtocolV2Channel, Session
from .transport.thp.protocol_v1 import ProtocolV1Channel from .transport.thp.protocol_v1 import ProtocolV1Channel
from trezorlib import exceptions
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from typing_extensions import Protocol from typing_extensions import Protocol
@ -1031,6 +1032,7 @@ class TrezorClientDebugLink(TrezorClient):
# by the device. # by the device.
protocol: ProtocolV1Channel | ProtocolV2Channel protocol: ProtocolV1Channel | ProtocolV2Channel
actual_responses: list[protobuf.MessageType] | None = None
def __init__( def __init__(
self, self,
@ -1056,8 +1058,6 @@ class TrezorClientDebugLink(TrezorClient):
transport.open() transport.open()
# set transport explicitly so that sync_responses can work # set transport explicitly so that sync_responses can work
super().__init__(transport)
self.transport = transport self.transport = transport
self.ui: DebugUI = DebugUI(self.debug) self.ui: DebugUI = DebugUI(self.debug)
@ -1071,6 +1071,13 @@ class TrezorClientDebugLink(TrezorClient):
self.pin_callback = get_pin self.pin_callback = get_pin
self.button_callback = self.ui.button_request self.button_callback = self.ui.button_request
try:
super().__init__(transport)
except exceptions.DeviceLockedException:
self.use_pin_sequence(["1234"])
self.debug.input(self.debug.encode_pin("1234"))
super().__init__(transport)
self.sync_responses() self.sync_responses()
# So that we can choose right screenshotting logic (T1 vs TT) # So that we can choose right screenshotting logic (T1 vs TT)
@ -1303,6 +1310,7 @@ class TrezorClientDebugLink(TrezorClient):
# Propagate the exception through the input flow, so that we see in # Propagate the exception through the input flow, so that we see in
# traceback where it is stuck. # traceback where it is stuck.
input_flow.throw(exc_type, value, traceback) input_flow.throw(exc_type, value, traceback)
self.actual_responses = None
@classmethod @classmethod
def _verify_responses( def _verify_responses(
@ -1406,6 +1414,14 @@ class TrezorClientDebugLink(TrezorClient):
self.actual_responses.append(resp) self.actual_responses.append(resp)
return resp return resp
def notify_read(self, msg: protobuf.MessageType) -> None:
pass
try:
if self.actual_responses is not None:
self.actual_responses.append(msg)
except Exception as e:
print(e)
def load_device( def load_device(
session: "Session", session: "Session",

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import logging import logging
import typing as t import typing as t
from .. import exceptions, messages, models from .. import exceptions, messages, models
from ..client import MAX_PIN_LENGTH from ..client import MAX_PIN_LENGTH
from ..protobuf import MessageType from ..protobuf import MessageType
@ -274,8 +275,12 @@ class SessionV2(Session):
self.channel.write(self.sid, msg) self.channel.write(self.sid, msg)
def _read(self) -> t.Any: def _read(self) -> t.Any:
from trezorlib.debuglink import TrezorClientDebugLink
msg = self.channel.read(self.sid) msg = self.channel.read(self.sid)
LOG.debug("reading message %s", type(msg)) LOG.debug("reading message %s", type(msg))
if isinstance(self.client, TrezorClientDebugLink):
self.client.notify_read(msg)
return msg return msg
def update_id_and_sid(self, id: bytes) -> None: def update_id_and_sid(self, id: bytes) -> None:

View File

@ -345,6 +345,10 @@ def _client_unlocked(
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
LOG.error(f"Failed to re-create a client: {e}") LOG.error(f"Failed to re-create a client: {e}")
sleep(LOCK_TIME) sleep(LOCK_TIME)
try:
_raw_client = _raw_client.get_new_client()
except Exception as e:
sleep(1.5)
_raw_client = _get_raw_client(request) _raw_client = _get_raw_client(request)
session = _raw_client.get_seedless_session() session = _raw_client.get_seedless_session()

View File

@ -58,6 +58,7 @@ def test_cancel_message_via_cancel(session: Session, message):
), ),
], ],
) )
@pytest.mark.protocol("protocol_v1")
def test_cancel_message_via_initialize(session: Session, message): def test_cancel_message_via_initialize(session: Session, message):
resp = session.call_raw(message) resp = session.call_raw(message)
assert isinstance(resp, m.ButtonRequest) assert isinstance(resp, m.ButtonRequest)

View File

@ -19,6 +19,7 @@ from pathlib import Path
import pytest import pytest
from trezorlib import btc, device, exceptions, messages, misc, models from trezorlib import btc, device, exceptions, messages, misc, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
@ -211,7 +212,8 @@ def test_apply_homescreen_toif(session: Session):
def test_apply_homescreen_jpeg(session: Session): def test_apply_homescreen_jpeg(session: Session):
with open(HERE / "test_bg.jpg", "rb") as f: with open(HERE / "test_bg.jpg", "rb") as f:
img = f.read() img = f.read()
# raise Exception("FAILS FOR SOME REASON ") if session.protocol_version is ProtocolVersion.V2:
raise Exception("Message too large for THP")
with session.client as client: with session.client as client:
_set_expected_responses(client) _set_expected_responses(client)
device.apply_settings(session, homescreen=img) device.apply_settings(session, homescreen=img)
@ -338,7 +340,9 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client: with pytest.raises(
exceptions.TrezorFailure, match="Forbidden key path"
), session.client as client:
client.set_expected_responses([messages.Failure]) client.set_expected_responses([messages.Failure])
get_bad_address() get_bad_address()
@ -351,11 +355,11 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways
with client: with session.client as client:
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address] [messages.ButtonRequest, messages.ButtonRequest, messages.Address]
) )
IF = InputFlowConfirmAllWarnings(session.client) IF = InputFlowConfirmAllWarnings(client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
get_bad_address() get_bad_address()
@ -365,11 +369,13 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client: with pytest.raises(
exceptions.TrezorFailure, match="Forbidden key path"
), session.client as client:
client.set_expected_responses([messages.Failure]) client.set_expected_responses([messages.Failure])
get_bad_address() get_bad_address()
with client: with session.client as client:
client.set_expected_responses(EXPECTED_RESPONSES_NOPIN) client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
device.apply_settings( device.apply_settings(
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
@ -377,11 +383,11 @@ def test_safety_checks(session: Session):
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily
with client: with session.client as client:
client.set_expected_responses( client.set_expected_responses(
[messages.ButtonRequest, messages.ButtonRequest, messages.Address] [messages.ButtonRequest, messages.ButtonRequest, messages.Address]
) )
if session.model is not models.T1B1: if client.model is not models.T1B1:
IF = InputFlowConfirmAllWarnings(session.client) IF = InputFlowConfirmAllWarnings(session.client)
client.set_input_flow(IF.get()) client.set_input_flow(IF.get())
get_bad_address() get_bad_address()
@ -403,7 +409,9 @@ def test_experimental_features(session: Session):
assert not session.features.experimental_features assert not session.features.experimental_features
with pytest.raises(exceptions.TrezorFailure, match="DataError"), client: with pytest.raises(
exceptions.TrezorFailure, match="DataError"
), session.client as client:
client.set_expected_responses([messages.Failure]) client.set_expected_responses([messages.Failure])
experimental_call() experimental_call()
@ -413,7 +421,7 @@ def test_experimental_features(session: Session):
assert session.features.experimental_features assert session.features.experimental_features
with client: with session.client as client:
client.set_expected_responses([messages.Nonce]) client.set_expected_responses([messages.Nonce])
experimental_call() experimental_call()