mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-05-03 23:49:02 +00:00
chore(tests): improve stability of tests
This commit is contained in:
parent
46269ef935
commit
cc7abfae8f
@ -40,6 +40,7 @@ from .tools import parse_path
|
|||||||
from .transport import Timeout
|
from .transport import Timeout
|
||||||
from .transport.session import ProtocolV2Channel, Session
|
from .transport.session import ProtocolV2Channel, Session
|
||||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||||
|
from trezorlib import exceptions
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
@ -1031,6 +1032,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# by the device.
|
# by the device.
|
||||||
|
|
||||||
protocol: ProtocolV1Channel | ProtocolV2Channel
|
protocol: ProtocolV1Channel | ProtocolV2Channel
|
||||||
|
actual_responses: list[protobuf.MessageType] | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -1056,8 +1058,6 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
transport.open()
|
transport.open()
|
||||||
|
|
||||||
# set transport explicitly so that sync_responses can work
|
# set transport explicitly so that sync_responses can work
|
||||||
super().__init__(transport)
|
|
||||||
|
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.ui: DebugUI = DebugUI(self.debug)
|
self.ui: DebugUI = DebugUI(self.debug)
|
||||||
|
|
||||||
@ -1071,6 +1071,13 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
self.pin_callback = get_pin
|
self.pin_callback = get_pin
|
||||||
self.button_callback = self.ui.button_request
|
self.button_callback = self.ui.button_request
|
||||||
|
|
||||||
|
try:
|
||||||
|
super().__init__(transport)
|
||||||
|
except exceptions.DeviceLockedException:
|
||||||
|
self.use_pin_sequence(["1234"])
|
||||||
|
self.debug.input(self.debug.encode_pin("1234"))
|
||||||
|
super().__init__(transport)
|
||||||
|
|
||||||
self.sync_responses()
|
self.sync_responses()
|
||||||
|
|
||||||
# So that we can choose right screenshotting logic (T1 vs TT)
|
# So that we can choose right screenshotting logic (T1 vs TT)
|
||||||
@ -1303,6 +1310,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# Propagate the exception through the input flow, so that we see in
|
# Propagate the exception through the input flow, so that we see in
|
||||||
# traceback where it is stuck.
|
# traceback where it is stuck.
|
||||||
input_flow.throw(exc_type, value, traceback)
|
input_flow.throw(exc_type, value, traceback)
|
||||||
|
self.actual_responses = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _verify_responses(
|
def _verify_responses(
|
||||||
@ -1406,6 +1414,14 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
self.actual_responses.append(resp)
|
self.actual_responses.append(resp)
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
def notify_read(self, msg: protobuf.MessageType) -> None:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
if self.actual_responses is not None:
|
||||||
|
self.actual_responses.append(msg)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
def load_device(
|
def load_device(
|
||||||
session: "Session",
|
session: "Session",
|
||||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
|
|
||||||
from .. import exceptions, messages, models
|
from .. import exceptions, messages, models
|
||||||
from ..client import MAX_PIN_LENGTH
|
from ..client import MAX_PIN_LENGTH
|
||||||
from ..protobuf import MessageType
|
from ..protobuf import MessageType
|
||||||
@ -274,8 +275,12 @@ class SessionV2(Session):
|
|||||||
self.channel.write(self.sid, msg)
|
self.channel.write(self.sid, msg)
|
||||||
|
|
||||||
def _read(self) -> t.Any:
|
def _read(self) -> t.Any:
|
||||||
|
from trezorlib.debuglink import TrezorClientDebugLink
|
||||||
|
|
||||||
msg = self.channel.read(self.sid)
|
msg = self.channel.read(self.sid)
|
||||||
LOG.debug("reading message %s", type(msg))
|
LOG.debug("reading message %s", type(msg))
|
||||||
|
if isinstance(self.client, TrezorClientDebugLink):
|
||||||
|
self.client.notify_read(msg)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def update_id_and_sid(self, id: bytes) -> None:
|
def update_id_and_sid(self, id: bytes) -> None:
|
||||||
|
@ -345,7 +345,11 @@ def _client_unlocked(
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
LOG.error(f"Failed to re-create a client: {e}")
|
LOG.error(f"Failed to re-create a client: {e}")
|
||||||
sleep(LOCK_TIME)
|
sleep(LOCK_TIME)
|
||||||
_raw_client = _get_raw_client(request)
|
try:
|
||||||
|
_raw_client = _raw_client.get_new_client()
|
||||||
|
except Exception as e:
|
||||||
|
sleep(1.5)
|
||||||
|
_raw_client = _get_raw_client(request)
|
||||||
|
|
||||||
session = _raw_client.get_seedless_session()
|
session = _raw_client.get_seedless_session()
|
||||||
wipe_device(session)
|
wipe_device(session)
|
||||||
|
@ -58,6 +58,7 @@ def test_cancel_message_via_cancel(session: Session, message):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@pytest.mark.protocol("protocol_v1")
|
||||||
def test_cancel_message_via_initialize(session: Session, message):
|
def test_cancel_message_via_initialize(session: Session, message):
|
||||||
resp = session.call_raw(message)
|
resp = session.call_raw(message)
|
||||||
assert isinstance(resp, m.ButtonRequest)
|
assert isinstance(resp, m.ButtonRequest)
|
||||||
|
@ -19,6 +19,7 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from trezorlib import btc, device, exceptions, messages, misc, models
|
from trezorlib import btc, device, exceptions, messages, misc, models
|
||||||
|
from trezorlib.client import ProtocolVersion
|
||||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
from trezorlib.tools import parse_path
|
from trezorlib.tools import parse_path
|
||||||
@ -211,7 +212,8 @@ def test_apply_homescreen_toif(session: Session):
|
|||||||
def test_apply_homescreen_jpeg(session: Session):
|
def test_apply_homescreen_jpeg(session: Session):
|
||||||
with open(HERE / "test_bg.jpg", "rb") as f:
|
with open(HERE / "test_bg.jpg", "rb") as f:
|
||||||
img = f.read()
|
img = f.read()
|
||||||
# raise Exception("FAILS FOR SOME REASON ")
|
if session.protocol_version is ProtocolVersion.V2:
|
||||||
|
raise Exception("Message too large for THP")
|
||||||
with session.client as client:
|
with session.client as client:
|
||||||
_set_expected_responses(client)
|
_set_expected_responses(client)
|
||||||
device.apply_settings(session, homescreen=img)
|
device.apply_settings(session, homescreen=img)
|
||||||
@ -338,7 +340,9 @@ def test_safety_checks(session: Session):
|
|||||||
|
|
||||||
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
|
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
|
||||||
|
|
||||||
with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client:
|
with pytest.raises(
|
||||||
|
exceptions.TrezorFailure, match="Forbidden key path"
|
||||||
|
), session.client as client:
|
||||||
client.set_expected_responses([messages.Failure])
|
client.set_expected_responses([messages.Failure])
|
||||||
get_bad_address()
|
get_bad_address()
|
||||||
|
|
||||||
@ -351,11 +355,11 @@ def test_safety_checks(session: Session):
|
|||||||
|
|
||||||
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways
|
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways
|
||||||
|
|
||||||
with client:
|
with session.client as client:
|
||||||
client.set_expected_responses(
|
client.set_expected_responses(
|
||||||
[messages.ButtonRequest, messages.ButtonRequest, messages.Address]
|
[messages.ButtonRequest, messages.ButtonRequest, messages.Address]
|
||||||
)
|
)
|
||||||
IF = InputFlowConfirmAllWarnings(session.client)
|
IF = InputFlowConfirmAllWarnings(client)
|
||||||
client.set_input_flow(IF.get())
|
client.set_input_flow(IF.get())
|
||||||
get_bad_address()
|
get_bad_address()
|
||||||
|
|
||||||
@ -365,11 +369,13 @@ def test_safety_checks(session: Session):
|
|||||||
|
|
||||||
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
|
assert session.features.safety_checks == messages.SafetyCheckLevel.Strict
|
||||||
|
|
||||||
with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client:
|
with pytest.raises(
|
||||||
|
exceptions.TrezorFailure, match="Forbidden key path"
|
||||||
|
), session.client as client:
|
||||||
client.set_expected_responses([messages.Failure])
|
client.set_expected_responses([messages.Failure])
|
||||||
get_bad_address()
|
get_bad_address()
|
||||||
|
|
||||||
with client:
|
with session.client as client:
|
||||||
client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
|
client.set_expected_responses(EXPECTED_RESPONSES_NOPIN)
|
||||||
device.apply_settings(
|
device.apply_settings(
|
||||||
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
|
session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily
|
||||||
@ -377,11 +383,11 @@ def test_safety_checks(session: Session):
|
|||||||
|
|
||||||
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily
|
assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily
|
||||||
|
|
||||||
with client:
|
with session.client as client:
|
||||||
client.set_expected_responses(
|
client.set_expected_responses(
|
||||||
[messages.ButtonRequest, messages.ButtonRequest, messages.Address]
|
[messages.ButtonRequest, messages.ButtonRequest, messages.Address]
|
||||||
)
|
)
|
||||||
if session.model is not models.T1B1:
|
if client.model is not models.T1B1:
|
||||||
IF = InputFlowConfirmAllWarnings(session.client)
|
IF = InputFlowConfirmAllWarnings(session.client)
|
||||||
client.set_input_flow(IF.get())
|
client.set_input_flow(IF.get())
|
||||||
get_bad_address()
|
get_bad_address()
|
||||||
@ -403,7 +409,9 @@ def test_experimental_features(session: Session):
|
|||||||
|
|
||||||
assert not session.features.experimental_features
|
assert not session.features.experimental_features
|
||||||
|
|
||||||
with pytest.raises(exceptions.TrezorFailure, match="DataError"), client:
|
with pytest.raises(
|
||||||
|
exceptions.TrezorFailure, match="DataError"
|
||||||
|
), session.client as client:
|
||||||
client.set_expected_responses([messages.Failure])
|
client.set_expected_responses([messages.Failure])
|
||||||
experimental_call()
|
experimental_call()
|
||||||
|
|
||||||
@ -413,7 +421,7 @@ def test_experimental_features(session: Session):
|
|||||||
|
|
||||||
assert session.features.experimental_features
|
assert session.features.experimental_features
|
||||||
|
|
||||||
with client:
|
with session.client as client:
|
||||||
client.set_expected_responses([messages.Nonce])
|
client.set_expected_responses([messages.Nonce])
|
||||||
experimental_call()
|
experimental_call()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user