From c1ec1d38bacac079db91bcc8d52e3b380e218f81 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 26 Mar 2025 16:50:28 +0100 Subject: [PATCH] fixup! test: update device tests --- tests/device_tests/test_session.py | 86 +++++---- .../test_session_id_and_passphrase.py | 166 +++++++++++------- 2 files changed, 155 insertions(+), 97 deletions(-) diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index 34ae585080..a3ce7b4de3 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -16,11 +16,11 @@ import pytest -from trezorlib import cardano, messages, models -from trezorlib.btc import get_public_node +from trezorlib import cardano, exceptions, messages, models from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure -from trezorlib.tools import parse_path +from trezorlib.tools import Address, parse_path +from trezorlib.transport.session import Session, SessionV1 from ..common import get_test_address @@ -30,6 +30,22 @@ XPUB = "xpub6BiVtCpG9fQPxnPmHXG8PhtzQdWC2Su4qWu6XW9tpWFYhxydCLJGrWBJZ5H6qTAHdPQ7 PIN4 = "1234" +def _get_public_node( + session: "Session", + address: "Address", + passphrase: str | None = None, +) -> messages.PublicKey: + + resp = session.call_raw( + messages.GetPublicKey(address_n=address), + ) + if isinstance(resp, messages.ButtonRequest): + resp = session._callback_button(resp) + if passphrase is not None: + resp = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) + return resp + + @pytest.mark.setup_client(pin=PIN4, passphrase="") def test_clear_session(client: Client): is_t1 = client.model is models.T1B1 @@ -44,13 +60,13 @@ def test_clear_session(client: Client): with session: client.use_pin_sequence([PIN4]) session.set_expected_responses(init_responses + cached_responses) - assert get_public_node(session, ADDRESS_N).xpub == XPUB + assert _get_public_node(session, ADDRESS_N, passphrase="").xpub == XPUB session.resume() with session: # pin and passphrase are cached session.set_expected_responses(cached_responses) - assert get_public_node(session, ADDRESS_N).xpub == XPUB + assert _get_public_node(session, ADDRESS_N).xpub == XPUB session.lock() session.end() @@ -60,13 +76,13 @@ def test_clear_session(client: Client): with session: client.use_pin_sequence([PIN4]) session.set_expected_responses(init_responses + cached_responses) - assert get_public_node(session, ADDRESS_N).xpub == XPUB + assert _get_public_node(session, ADDRESS_N, passphrase="").xpub == XPUB session.resume() with session: # pin and passphrase are cached session.set_expected_responses(cached_responses) - assert get_public_node(session, ADDRESS_N).xpub == XPUB + assert _get_public_node(session, ADDRESS_N).xpub == XPUB def test_end_session(client: Client): @@ -109,9 +125,10 @@ def test_cannot_resume_ended_session(client: Client): assert session.id == session_id session.end() - session.resume() + with pytest.raises(exceptions.FailedSessionResumption) as e: + session.resume() - assert session.id != session_id + assert e.value.received_session_id != session_id def test_end_session_only_current(client: Client): @@ -124,8 +141,10 @@ def test_end_session_only_current(client: Client): # assert client.session_id is None # resume ended session - session_b.resume() - assert session_b.id != session_b_id + with pytest.raises(exceptions.FailedSessionResumption) as e: + session_b.resume() + + assert e.value.received_session_id != session_b_id # resume first session that was not ended session_a.resume() @@ -136,14 +155,7 @@ def test_end_session_only_current(client: Client): def test_session_recycling(client: Client): session = client.get_session(passphrase="TREZOR") with session: - session.set_expected_responses( - [ - messages.PassphraseRequest, - messages.ButtonRequest, - messages.ButtonRequest, - messages.Address, - ] - ) + session.set_expected_responses([messages.Address]) address = get_test_address(session) # create and close 100 sessions - more than the session limit @@ -174,10 +186,10 @@ def test_derive_cardano_empty_session(client: Client): assert session.id == session_id # restarting same session should go well with any setting - session_3 = client.get_session(session_id=session_id, derive_cardano=False) - assert session_id == session_3.id - session_4 = client.get_session(session_id=session_id, derive_cardano=True) - assert session_id == session_4.id + session.init_session(derive_cardano=False) + assert session_id == session.id + session.init_session(derive_cardano=True) + assert session_id == session.id @pytest.mark.altcoin @@ -199,24 +211,30 @@ def test_derive_cardano_running_session(client: Client): assert session.id == session_id # restarting same session should go well if we _don't_ want to derive cardano - session_3 = client.get_session(session_id=session_id, derive_cardano=False) - assert session_3.id == session.id + session.init_session(derive_cardano=False) + assert session.id == session_id # restarting with derive_cardano=True should kill old session and create new one - session_4 = client.get_session(derive_cardano=True) - session_4_id = session_4.id - assert session_4_id != session.id + with pytest.raises(exceptions.FailedSessionResumption) as e: + session.init_session(derive_cardano=True) + session_2 = SessionV1(client, e.value.received_session_id) + session_2.derive_cardano = True + session_2_id = session_2.id + assert session_2_id != session.id # new session should have Cardano capability - cardano.get_public_key(session_4, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session_2, parse_path("m/44h/1815h/0h")) # restarting with derive_cardano=True should keep same session - session_4.resume() - assert session_4.id == session_4_id + session_2.resume() + assert session_2.id == session_2_id # restarting with derive_cardano=False should kill old session and create new one - session_6 = client.get_session(session_id=session_4.id, derive_cardano=False) - assert session_4.id != session_6.id + with pytest.raises(exceptions.FailedSessionResumption) as e: + session_2.init_session(derive_cardano=False) + session_3 = SessionV1(client, e.value.received_session_id) + + assert session_2.id != session_3.id with pytest.raises(TrezorFailure, match="not enabled"): - cardano.get_public_key(session_6, parse_path("m/44h/1815h/0h")) + cardano.get_public_key(session_3, parse_path("m/44h/1815h/0h")) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 8b7fd0b281..9049bd8643 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -25,6 +25,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure from trezorlib.messages import FailureType, SafetyCheckLevel from trezorlib.tools import parse_path +from trezorlib.transport.session import SessionV1 from .. import translations as TR @@ -52,10 +53,10 @@ SESSIONS_STORED = 10 def _get_xpub( session: Session, - expected_passphrase_req: bool = False, + passphrase: str | None = None, ): """Get XPUB and check that the appropriate passphrase flow has happened.""" - if expected_passphrase_req: + if passphrase is not None: expected_responses = [ messages.PassphraseRequest, messages.ButtonRequest, @@ -67,18 +68,35 @@ def _get_xpub( with session: session.set_expected_responses(expected_responses) - result = session.call(XPUB_REQUEST) + result = session.call_raw(XPUB_REQUEST) + if passphrase is not None: + result = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) + while isinstance(result, messages.ButtonRequest): + result = session._callback_button(result) return result.xpub +def _get_session(client: Client, session_id=None, derive_cardano=False) -> Session: + """Call Initialize, check and return the session.""" + + from trezorlib.transport.session import SessionV1 + + session = SessionV1.new( + client=client, derive_cardano=derive_cardano, session_id=session_id + ) + return Session(session) + + @pytest.mark.setup_client(passphrase=True) def test_session_with_passphrase(client: Client): - session = client.get_session(passphrase="A") + # session = client.get_session(passphrase="A") + session = _get_session(client) session_id = session.id + # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] + assert _get_xpub(session, passphrase="A") == XPUB_PASSPHRASES["A"] # Call Initialize again, this time with the received session id and then call # GetPublicKey. The passphrase should be cached now so Trezor must @@ -89,30 +107,42 @@ def test_session_with_passphrase(client: Client): # If we set session id in Initialize to None, the cache will be cleared # and Trezor will ask for the passphrase again. - session3 = client.get_session(passphrase="A") - assert session3.id != session_id - assert _get_xpub(session3, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] + session_2 = _get_session(client) + assert session_2.id != session_id + assert _get_xpub(session_2, passphrase="A") == XPUB_PASSPHRASES["A"] - # Unknown session id has the same result as setting it to None. - session4 = client.get_session(session_id=b"X" * 32, passphrase="A") - assert session4.id != b"X" * 32 - assert session4.id != session_id - assert session4.id != session3.id - assert _get_xpub(session4, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] + # Unknown session id leads to FailedSessionResumption in trezorlib. + # Trezor ignores the invalid session_id and creates a new session + with pytest.raises(exceptions.FailedSessionResumption) as e: + _get_session(client, session_id=b"X" * 32) + + session_3 = Session(SessionV1(client, e.value.received_session_id)) + + assert session_3.id is not None + assert len(session_3.id) == 32 + assert session_3.id != b"X" * 32 + assert session_3.id != session_id + assert session_3.id != session_2.id + assert _get_xpub(session_3, passphrase="A") == XPUB_PASSPHRASES["A"] @pytest.mark.setup_client(passphrase=True) def test_multiple_sessions(client: Client): # start SESSIONS_STORED sessions + SESSIONS_STORED = 10 session_ids = [] sessions = [] for _ in range(SESSIONS_STORED): - session = client.get_session() + session = _get_session(client) sessions.append(session) session_ids.append(session.id) # Resume each session for i in range(SESSIONS_STORED): + if i == 0: + pass + # raise Exception(sessions[i]._session.id) + sessions[i].resume() assert session_ids[i] == sessions[i].id @@ -125,8 +155,9 @@ def test_multiple_sessions(client: Client): assert session_ids[i] == sessions[i].id # Resuming session 0 will not work - sessions[0].resume() - assert session_ids[0] != sessions[0].id + with pytest.raises(exceptions.FailedSessionResumption) as e: + sessions[0].resume() + assert session_ids[0] != e.value.received_session_id # New session bumped out the least-recently-used anonymous session. # Resuming session 1 through SESSIONS_STORED will still work @@ -135,20 +166,21 @@ def test_multiple_sessions(client: Client): assert session_ids[i] == sessions[i].id # Creating a new session replaces session_ids[0] again - client.get_session() + _get_session(client) # Resuming all sessions one by one will in turn bump out the previous session. for i in range(SESSIONS_STORED): - sessions[i].resume() - assert session_ids[i] != sessions[i].id + with pytest.raises(exceptions.FailedSessionResumption) as e: + sessions[i].resume() + assert session_ids[i] != e.value.received_session_id @pytest.mark.setup_client(passphrase=True) def test_multiple_passphrases(client: Client): # start a session - session_a = client.get_session(passphrase="A") + session_a = _get_session(client) session_a_id = session_a.id - assert _get_xpub(session_a, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] + assert _get_xpub(session_a, passphrase="A") == XPUB_PASSPHRASES["A"] # start it again wit the same session id session_a.resume() # session is the same @@ -157,10 +189,10 @@ def test_multiple_passphrases(client: Client): assert _get_xpub(session_a) == XPUB_PASSPHRASES["A"] # start a second session - session_b = client.get_session(passphrase="B") + session_b = _get_session(client) session_b_id = session_b.id # new session -> new session id and passphrase prompt - assert _get_xpub(session_b, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] + assert _get_xpub(session_b, passphrase="B") == XPUB_PASSPHRASES["B"] # provide the same session id -> must not ask for passphrase again. session_b.resume() @@ -188,11 +220,11 @@ def test_max_sessions_with_passphrases(client: Client): session_ids = {} sessions = {} for passphrase, xpub in XPUB_PASSPHRASES.items(): - session = client.get_session(passphrase=passphrase) + session = _get_session(client) assert session.id not in session_ids.values() session_ids[passphrase] = session.id sessions[passphrase] = session - assert _get_xpub(session, expected_passphrase_req=True) == xpub + assert _get_xpub(session, passphrase=passphrase) == xpub # passphrase is not prompted for the started the sessions, regardless the order # let's try 20 different orderings @@ -212,25 +244,24 @@ def test_max_sessions_with_passphrases(client: Client): assert _get_xpub(sessions[passphrase]) == XPUB_PASSPHRASES[passphrase] # creating one more session will exceed the limit - new_session = client.get_session(passphrase="XX") + new_session = _get_session(client) # new session asks for passphrase - _get_xpub(new_session, expected_passphrase_req=True) + _get_xpub(new_session, passphrase="XX") # restoring the sessions in reverse will evict the next-up session for passphrase in reversed(passphrases): - sessions[passphrase].resume() - _get_xpub( - sessions[passphrase], - expected_passphrase_req=True, - ) # passphrase is prompted + with pytest.raises(exceptions.FailedSessionResumption) as e: + sessions[passphrase].resume() + sessions[passphrase] = Session(SessionV1(client, e.value.received_session_id)) + _get_xpub(sessions[passphrase], passphrase=passphrase) # passphrase is prompted def test_session_enable_passphrase(client: Client): # Let's start the communication by calling Initialize. - session = client.get_session(passphrase="") + session = _get_session(client) # Trezor will not prompt for passphrase because it is turned off. - assert _get_xpub(session, expected_passphrase_req=False) == XPUB_PASSPHRASE_NONE + assert _get_xpub(session) == XPUB_PASSPHRASE_NONE # Turn on passphrase. # Emit the call explicitly to avoid ClearSession done by the library function @@ -244,16 +275,16 @@ def test_session_enable_passphrase(client: Client): assert _get_xpub(session) == XPUB_PASSPHRASE_NONE # We clear the session id now, so the passphrase should be asked. - new_session = client.get_session(passphrase="A") + new_session = _get_session(client) assert session_id != new_session.id - assert _get_xpub(new_session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] + assert _get_xpub(new_session, passphrase="A") == XPUB_PASSPHRASES["A"] @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) def test_passphrase_on_device(client: Client): # _init_session(client) - session = client.get_session(passphrase="A") + session = _get_session(client) # try to get xpub with passphrase on host: response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) @@ -269,7 +300,7 @@ def test_passphrase_on_device(client: Client): assert response.xpub == XPUB_PASSPHRASES["A"] # make a new session - session2 = session.client.get_session(passphrase="A") + session2 = _get_session(client) # try to get xpub with passphrase on device: response = session2.call_raw(XPUB_REQUEST) @@ -290,10 +321,10 @@ def test_passphrase_on_device(client: Client): @pytest.mark.models("core") @pytest.mark.setup_client(passphrase=True) +@pytest.mark.uninitialized_session def test_passphrase_always_on_device(client: Client): # Let's start the communication by calling Initialize. - session = client.get_session() - # session_id = _init_session(client) + session = _get_session(client) # Force passphrase entry on Trezor. response = session.call(messages.ApplySettings(passphrase_always_on_device=True)) @@ -314,7 +345,7 @@ def test_passphrase_always_on_device(client: Client): assert response.xpub == XPUB_PASSPHRASE_NONE # In case we want to add a new passphrase we need to send session_id = None. - new_session = client.get_session(passphrase="A") + new_session = _get_session(client) response = new_session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.ButtonRequest) client.debug.input("A") # Input non-empty passphrase. @@ -325,6 +356,7 @@ def test_passphrase_always_on_device(client: Client): @pytest.mark.models("legacy") @pytest.mark.setup_client(passphrase="") +@pytest.mark.uninitialized_session def test_passphrase_on_device_not_possible_on_t1(session: Session): # This setting makes no sense on T1. response = session.call_raw( @@ -342,6 +374,7 @@ def test_passphrase_on_device_not_possible_on_t1(session: Session): @pytest.mark.setup_client(passphrase=True) +@pytest.mark.uninitialized_session def test_passphrase_ack_mismatch(session: Session): response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) @@ -350,7 +383,8 @@ def test_passphrase_ack_mismatch(session: Session): assert response.code == FailureType.DataError -@pytest.mark.setup_client(passphrase="") +@pytest.mark.setup_client(passphrase=True) +@pytest.mark.uninitialized_session def test_passphrase_missing(session: Session): response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) @@ -368,9 +402,10 @@ def test_passphrase_missing(session: Session): @pytest.mark.setup_client(passphrase=True) +@pytest.mark.uninitialized_session def test_passphrase_length(client: Client): def call(passphrase: str, expected_result: bool): - session = client.get_session(passphrase=passphrase) + session = _get_session(client) response = session.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PassphraseRequest) try: @@ -405,7 +440,7 @@ def test_hide_passphrase_from_host(client: Client): device.apply_settings(session, hide_passphrase_from_host=True) passphrase = "abc" - session = client.get_session(passphrase=passphrase) + session = _get_session(client) with session: def input_flow(): @@ -430,15 +465,17 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - result = session.call(XPUB_REQUEST) - assert isinstance(result, messages.PublicKey) - xpub_hidden_passphrase = result.xpub + resp = session.call_raw(XPUB_REQUEST) + resp = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) + resp = session._callback_button(resp) + assert isinstance(resp, messages.PublicKey) + xpub_hidden_passphrase = resp.xpub # Turning it off device.apply_settings(session, hide_passphrase_from_host=False) # Starting new session, otherwise the passphrase would be cached - session = client.get_session(passphrase=passphrase) + session = _get_session(client) with session: @@ -465,22 +502,29 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - result = session.call(XPUB_REQUEST) - assert isinstance(result, messages.PublicKey) - xpub_shown_passphrase = result.xpub + resp = session.call_raw(XPUB_REQUEST) + assert isinstance(resp, messages.PassphraseRequest) + resp = session.call_raw(messages.PassphraseAck(passphrase=passphrase)) + resp = session._callback_button(resp) + resp = session._callback_button(resp) + assert isinstance(resp, messages.PublicKey) + xpub_shown_passphrase = resp.xpub assert xpub_hidden_passphrase == xpub_shown_passphrase -def _get_xpub_cardano(session: Session, expected_passphrase_req: bool = False): +def _get_xpub_cardano( + session: Session, + passphrase: str | None = None, +): msg = messages.CardanoGetPublicKey( address_n=parse_path("m/44h/1815h/0h/0/0"), derivation_type=messages.CardanoDerivationType.ICARUS, ) response = session.call_raw(msg) - if expected_passphrase_req: + if passphrase is not None: assert isinstance(response, messages.PassphraseRequest) - response = session.call(messages.PassphraseAck(passphrase=session.passphrase)) + response = session.call(messages.PassphraseAck(passphrase=passphrase)) assert isinstance(response, messages.CardanoPublicKey) return response.xpub @@ -497,8 +541,8 @@ def test_cardano_passphrase(client: Client): # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - session = client.get_session(passphrase="B", derive_cardano=True) - assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] + session = _get_session(client, derive_cardano=True) + assert _get_xpub(session, passphrase="B") == XPUB_PASSPHRASES["B"] # The passphrase is now cached for non-Cardano coins. assert _get_xpub(session) == XPUB_PASSPHRASES["B"] @@ -513,14 +557,10 @@ def test_cardano_passphrase(client: Client): assert _get_xpub_cardano(session) == XPUB_CARDANO_PASSPHRASE_B # New session will destroy the state - new_session = client.get_session(passphrase="A", derive_cardano=True) - # _init_session(client, derive_cardano=True) + new_session = _get_session(client, derive_cardano=True) # Cardano must ask for passphrase again - assert ( - _get_xpub_cardano(new_session, expected_passphrase_req=True) - == XPUB_CARDANO_PASSPHRASE_A - ) + assert _get_xpub_cardano(new_session, passphrase="A") == XPUB_CARDANO_PASSPHRASE_A # Passphrase is now cached for Cardano assert _get_xpub_cardano(new_session) == XPUB_CARDANO_PASSPHRASE_A