diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 353766f1cc..ac04d7da7e 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -41,6 +41,7 @@ from .client import ( from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError from .log import DUMP_BYTES from .messages import Capability, DebugWaitType +from .protobuf import MessageType from .tools import parse_path from .transport.session import Session, SessionV1 from .transport.thp.protocol_v1 import ProtocolV1 @@ -1387,7 +1388,7 @@ class TrezorClientDebugLink(TrezorClient): def send_passphrase( passphrase: str | None = None, on_device: bool | None = None - ) -> t.Any: + ) -> MessageType: msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) resp = session.call_raw(msg) if isinstance(resp, messages.Deprecated_PassphraseStateRequest): diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index ea01ccc780..7bdb94f2a5 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -163,15 +163,10 @@ class SessionV1(Session): messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) ) if isinstance(self.passphrase, str): - self.passphrase_callback = _send_passphrase + self.passphrase_callback = self.client.passphrase_callback self._id = resp.session_id -def _send_passphrase(session: Session, resp: t.Any) -> None: - assert isinstance(session.passphrase, str) - session.call(messages.PassphraseAck(passphrase=session.passphrase)) - - def _callback_button(session: Session, msg: t.Any) -> t.Any: print("Please confirm action on your Trezor device.") # TODO how to handle UI? return session.call(messages.ButtonAck()) diff --git a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py index 7616dae5e0..dd25fc1342 100644 --- a/tests/device_tests/reset_recovery/test_reset_slip39_basic.py +++ b/tests/device_tests/reset_recovery/test_reset_slip39_basic.py @@ -83,6 +83,7 @@ def test_reset_device_slip39_basic_256(session: Session): @pytest.mark.setup_client(uninitialized=True) +@pytest.mark.uninitialized_session def test_reset_entropy_check(session: Session): member_threshold = 3 @@ -103,21 +104,23 @@ def test_reset_entropy_check(session: Session): entropy_check_count=3, _get_entropy=MOCK_GET_ENTROPY, ) - # Generate the master secret locally. - internal_entropy = client.debug.state().reset_entropy + internal_entropy = session.client.debug.state().reset_entropy assert internal_entropy is not None secret = generate_entropy(strength, internal_entropy, EXTERNAL_ENTROPY) # Check that all combinations will result in the correct master secret. validate_mnemonics(IF.mnemonics, member_threshold, secret) + # Create a session with cache backing + session = session.client.get_session() + # Check that the device is properly initialized. - assert client.features.initialized is True - assert client.features.backup_availability == BackupAvailability.NotAvailable - assert client.features.pin_protection is False - assert client.features.passphrase_protection is False - assert client.features.backup_type is BackupType.Slip39_Basic_Extendable + assert session.features.initialized is True + assert session.features.backup_availability == BackupAvailability.NotAvailable + assert session.features.pin_protection is False + assert session.features.passphrase_protection is False + assert session.features.backup_type is BackupType.Slip39_Basic_Extendable # Check that the XPUBs are the same as those from the entropy check. for path, xpub in path_xpubs: diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index f38afce7af..a070ff93cb 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -79,6 +79,7 @@ def _get_xpub( @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_session_with_passphrase(client: Client): session = Session(client.get_session(passphrase="A")) @@ -108,6 +109,7 @@ def test_session_with_passphrase(client: Client): @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_multiple_sessions(client: Client): # start SESSIONS_STORED sessions session_ids = [] @@ -150,6 +152,7 @@ def test_multiple_sessions(client: Client): @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_multiple_passphrases(client: Client): # start a session session_a = Session(client.get_session(passphrase="A")) @@ -186,6 +189,7 @@ def test_multiple_passphrases(client: Client): @pytest.mark.slow @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_max_sessions_with_passphrases(client: Client): # for the following tests, we are using as many passphrases as there are available sessions assert len(XPUB_PASSPHRASES) == SESSIONS_STORED @@ -232,6 +236,7 @@ def test_max_sessions_with_passphrases(client: Client): ) # passphrase is prompted +@pytest.mark.protocol("protocol_v1") def test_session_enable_passphrase(client: Client): # Let's start the communication by calling Initialize. session = Session(client.get_session(passphrase="")) @@ -258,6 +263,7 @@ def test_session_enable_passphrase(client: Client): @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_passphrase_on_device(client: Client): # _init_session(client) session = client.get_session(passphrase="A") @@ -297,6 +303,7 @@ def test_passphrase_on_device(client: Client): @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_passphrase_always_on_device(client: Client): # Let's start the communication by calling Initialize. session = client.get_session() @@ -332,6 +339,7 @@ def test_passphrase_always_on_device(client: Client): @pytest.mark.models("legacy") @pytest.mark.setup_client(passphrase="") +@pytest.mark.protocol("protocol_v1") def test_passphrase_on_device_not_possible_on_t1(client: Client): # This setting makes no sense on T1. response = client.call_raw(messages.ApplySettings(passphrase_always_on_device=True)) @@ -347,6 +355,7 @@ def test_passphrase_on_device_not_possible_on_t1(client: Client): @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_passphrase_ack_mismatch(session: Session): response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) @@ -356,6 +365,7 @@ def test_passphrase_ack_mismatch(session: Session): @pytest.mark.setup_client(passphrase="") +@pytest.mark.protocol("protocol_v1") def test_passphrase_missing(session: Session): response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) @@ -373,6 +383,7 @@ def test_passphrase_missing(session: Session): @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_passphrase_length(client: Client): def call(passphrase: str, expected_result: bool): session = client.get_session(passphrase=passphrase) @@ -398,6 +409,7 @@ def test_passphrase_length(client: Client): @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_hide_passphrase_from_host(client: Client): # Without safety checks, turning it on fails session = client.get_management_session() @@ -495,6 +507,7 @@ def _get_xpub_cardano(session: Session, expected_passphrase_req: bool = False): @pytest.mark.models("core") @pytest.mark.altcoin @pytest.mark.setup_client(passphrase=True) +@pytest.mark.protocol("protocol_v1") def test_cardano_passphrase(client: Client): # Cardano has a separate derivation method that needs to access the plaintext # of the passphrase.