1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-19 14:38:47 +00:00

fixup! test: update device tests

This commit is contained in:
M1nd3r 2025-03-26 16:50:28 +01:00
parent 285ba53db2
commit c1ec1d38ba
2 changed files with 155 additions and 97 deletions

View File

@ -16,11 +16,11 @@
import pytest import pytest
from trezorlib import cardano, messages, models from trezorlib import cardano, exceptions, messages, models
from trezorlib.btc import get_public_node
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path from trezorlib.tools import Address, parse_path
from trezorlib.transport.session import Session, SessionV1
from ..common import get_test_address from ..common import get_test_address
@ -30,6 +30,22 @@ XPUB = "xpub6BiVtCpG9fQPxnPmHXG8PhtzQdWC2Su4qWu6XW9tpWFYhxydCLJGrWBJZ5H6qTAHdPQ7
PIN4 = "1234" PIN4 = "1234"
def _get_public_node(
session: "Session",
address: "Address",
passphrase: str | None = None,
) -> messages.PublicKey:
resp = session.call_raw(
messages.GetPublicKey(address_n=address),
)
if isinstance(resp, messages.ButtonRequest):
resp = session._callback_button(resp)
if passphrase is not None:
resp = session.call_raw(messages.PassphraseAck(passphrase=passphrase))
return resp
@pytest.mark.setup_client(pin=PIN4, passphrase="") @pytest.mark.setup_client(pin=PIN4, passphrase="")
def test_clear_session(client: Client): def test_clear_session(client: Client):
is_t1 = client.model is models.T1B1 is_t1 = client.model is models.T1B1
@ -44,13 +60,13 @@ def test_clear_session(client: Client):
with session: with session:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
session.set_expected_responses(init_responses + cached_responses) session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB assert _get_public_node(session, ADDRESS_N, passphrase="").xpub == XPUB
session.resume() session.resume()
with session: with session:
# pin and passphrase are cached # pin and passphrase are cached
session.set_expected_responses(cached_responses) session.set_expected_responses(cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB assert _get_public_node(session, ADDRESS_N).xpub == XPUB
session.lock() session.lock()
session.end() session.end()
@ -60,13 +76,13 @@ def test_clear_session(client: Client):
with session: with session:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
session.set_expected_responses(init_responses + cached_responses) session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB assert _get_public_node(session, ADDRESS_N, passphrase="").xpub == XPUB
session.resume() session.resume()
with session: with session:
# pin and passphrase are cached # pin and passphrase are cached
session.set_expected_responses(cached_responses) session.set_expected_responses(cached_responses)
assert get_public_node(session, ADDRESS_N).xpub == XPUB assert _get_public_node(session, ADDRESS_N).xpub == XPUB
def test_end_session(client: Client): def test_end_session(client: Client):
@ -109,9 +125,10 @@ def test_cannot_resume_ended_session(client: Client):
assert session.id == session_id assert session.id == session_id
session.end() session.end()
session.resume() with pytest.raises(exceptions.FailedSessionResumption) as e:
session.resume()
assert session.id != session_id assert e.value.received_session_id != session_id
def test_end_session_only_current(client: Client): def test_end_session_only_current(client: Client):
@ -124,8 +141,10 @@ def test_end_session_only_current(client: Client):
# assert client.session_id is None # assert client.session_id is None
# resume ended session # resume ended session
session_b.resume() with pytest.raises(exceptions.FailedSessionResumption) as e:
assert session_b.id != session_b_id session_b.resume()
assert e.value.received_session_id != session_b_id
# resume first session that was not ended # resume first session that was not ended
session_a.resume() session_a.resume()
@ -136,14 +155,7 @@ def test_end_session_only_current(client: Client):
def test_session_recycling(client: Client): def test_session_recycling(client: Client):
session = client.get_session(passphrase="TREZOR") session = client.get_session(passphrase="TREZOR")
with session: with session:
session.set_expected_responses( session.set_expected_responses([messages.Address])
[
messages.PassphraseRequest,
messages.ButtonRequest,
messages.ButtonRequest,
messages.Address,
]
)
address = get_test_address(session) address = get_test_address(session)
# create and close 100 sessions - more than the session limit # create and close 100 sessions - more than the session limit
@ -174,10 +186,10 @@ def test_derive_cardano_empty_session(client: Client):
assert session.id == session_id assert session.id == session_id
# restarting same session should go well with any setting # restarting same session should go well with any setting
session_3 = client.get_session(session_id=session_id, derive_cardano=False) session.init_session(derive_cardano=False)
assert session_id == session_3.id assert session_id == session.id
session_4 = client.get_session(session_id=session_id, derive_cardano=True) session.init_session(derive_cardano=True)
assert session_id == session_4.id assert session_id == session.id
@pytest.mark.altcoin @pytest.mark.altcoin
@ -199,24 +211,30 @@ def test_derive_cardano_running_session(client: Client):
assert session.id == session_id assert session.id == session_id
# restarting same session should go well if we _don't_ want to derive cardano # restarting same session should go well if we _don't_ want to derive cardano
session_3 = client.get_session(session_id=session_id, derive_cardano=False) session.init_session(derive_cardano=False)
assert session_3.id == session.id assert session.id == session_id
# restarting with derive_cardano=True should kill old session and create new one # restarting with derive_cardano=True should kill old session and create new one
session_4 = client.get_session(derive_cardano=True) with pytest.raises(exceptions.FailedSessionResumption) as e:
session_4_id = session_4.id session.init_session(derive_cardano=True)
assert session_4_id != session.id session_2 = SessionV1(client, e.value.received_session_id)
session_2.derive_cardano = True
session_2_id = session_2.id
assert session_2_id != session.id
# new session should have Cardano capability # new session should have Cardano capability
cardano.get_public_key(session_4, parse_path("m/44h/1815h/0h")) cardano.get_public_key(session_2, parse_path("m/44h/1815h/0h"))
# restarting with derive_cardano=True should keep same session # restarting with derive_cardano=True should keep same session
session_4.resume() session_2.resume()
assert session_4.id == session_4_id assert session_2.id == session_2_id
# restarting with derive_cardano=False should kill old session and create new one # restarting with derive_cardano=False should kill old session and create new one
session_6 = client.get_session(session_id=session_4.id, derive_cardano=False) with pytest.raises(exceptions.FailedSessionResumption) as e:
assert session_4.id != session_6.id session_2.init_session(derive_cardano=False)
session_3 = SessionV1(client, e.value.received_session_id)
assert session_2.id != session_3.id
with pytest.raises(TrezorFailure, match="not enabled"): with pytest.raises(TrezorFailure, match="not enabled"):
cardano.get_public_key(session_6, parse_path("m/44h/1815h/0h")) cardano.get_public_key(session_3, parse_path("m/44h/1815h/0h"))

View File

@ -25,6 +25,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.messages import FailureType, SafetyCheckLevel from trezorlib.messages import FailureType, SafetyCheckLevel
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from trezorlib.transport.session import SessionV1
from .. import translations as TR from .. import translations as TR
@ -52,10 +53,10 @@ SESSIONS_STORED = 10
def _get_xpub( def _get_xpub(
session: Session, session: Session,
expected_passphrase_req: bool = False, passphrase: str | None = None,
): ):
"""Get XPUB and check that the appropriate passphrase flow has happened.""" """Get XPUB and check that the appropriate passphrase flow has happened."""
if expected_passphrase_req: if passphrase is not None:
expected_responses = [ expected_responses = [
messages.PassphraseRequest, messages.PassphraseRequest,
messages.ButtonRequest, messages.ButtonRequest,
@ -67,18 +68,35 @@ def _get_xpub(
with session: with session:
session.set_expected_responses(expected_responses) session.set_expected_responses(expected_responses)
result = session.call(XPUB_REQUEST) result = session.call_raw(XPUB_REQUEST)
if passphrase is not None:
result = session.call_raw(messages.PassphraseAck(passphrase=passphrase))
while isinstance(result, messages.ButtonRequest):
result = session._callback_button(result)
return result.xpub return result.xpub
def _get_session(client: Client, session_id=None, derive_cardano=False) -> Session:
"""Call Initialize, check and return the session."""
from trezorlib.transport.session import SessionV1
session = SessionV1.new(
client=client, derive_cardano=derive_cardano, session_id=session_id
)
return Session(session)
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
def test_session_with_passphrase(client: Client): def test_session_with_passphrase(client: Client):
session = client.get_session(passphrase="A") # session = client.get_session(passphrase="A")
session = _get_session(client)
session_id = session.id session_id = session.id
# GetPublicKey requires passphrase and since it is not cached, # GetPublicKey requires passphrase and since it is not cached,
# Trezor will prompt for it. # Trezor will prompt for it.
assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] assert _get_xpub(session, passphrase="A") == XPUB_PASSPHRASES["A"]
# Call Initialize again, this time with the received session id and then call # Call Initialize again, this time with the received session id and then call
# GetPublicKey. The passphrase should be cached now so Trezor must # GetPublicKey. The passphrase should be cached now so Trezor must
@ -89,30 +107,42 @@ def test_session_with_passphrase(client: Client):
# If we set session id in Initialize to None, the cache will be cleared # If we set session id in Initialize to None, the cache will be cleared
# and Trezor will ask for the passphrase again. # and Trezor will ask for the passphrase again.
session3 = client.get_session(passphrase="A") session_2 = _get_session(client)
assert session3.id != session_id assert session_2.id != session_id
assert _get_xpub(session3, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] assert _get_xpub(session_2, passphrase="A") == XPUB_PASSPHRASES["A"]
# Unknown session id has the same result as setting it to None. # Unknown session id leads to FailedSessionResumption in trezorlib.
session4 = client.get_session(session_id=b"X" * 32, passphrase="A") # Trezor ignores the invalid session_id and creates a new session
assert session4.id != b"X" * 32 with pytest.raises(exceptions.FailedSessionResumption) as e:
assert session4.id != session_id _get_session(client, session_id=b"X" * 32)
assert session4.id != session3.id
assert _get_xpub(session4, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] session_3 = Session(SessionV1(client, e.value.received_session_id))
assert session_3.id is not None
assert len(session_3.id) == 32
assert session_3.id != b"X" * 32
assert session_3.id != session_id
assert session_3.id != session_2.id
assert _get_xpub(session_3, passphrase="A") == XPUB_PASSPHRASES["A"]
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
def test_multiple_sessions(client: Client): def test_multiple_sessions(client: Client):
# start SESSIONS_STORED sessions # start SESSIONS_STORED sessions
SESSIONS_STORED = 10
session_ids = [] session_ids = []
sessions = [] sessions = []
for _ in range(SESSIONS_STORED): for _ in range(SESSIONS_STORED):
session = client.get_session() session = _get_session(client)
sessions.append(session) sessions.append(session)
session_ids.append(session.id) session_ids.append(session.id)
# Resume each session # Resume each session
for i in range(SESSIONS_STORED): for i in range(SESSIONS_STORED):
if i == 0:
pass
# raise Exception(sessions[i]._session.id)
sessions[i].resume() sessions[i].resume()
assert session_ids[i] == sessions[i].id assert session_ids[i] == sessions[i].id
@ -125,8 +155,9 @@ def test_multiple_sessions(client: Client):
assert session_ids[i] == sessions[i].id assert session_ids[i] == sessions[i].id
# Resuming session 0 will not work # Resuming session 0 will not work
sessions[0].resume() with pytest.raises(exceptions.FailedSessionResumption) as e:
assert session_ids[0] != sessions[0].id sessions[0].resume()
assert session_ids[0] != e.value.received_session_id
# New session bumped out the least-recently-used anonymous session. # New session bumped out the least-recently-used anonymous session.
# Resuming session 1 through SESSIONS_STORED will still work # Resuming session 1 through SESSIONS_STORED will still work
@ -135,20 +166,21 @@ def test_multiple_sessions(client: Client):
assert session_ids[i] == sessions[i].id assert session_ids[i] == sessions[i].id
# Creating a new session replaces session_ids[0] again # Creating a new session replaces session_ids[0] again
client.get_session() _get_session(client)
# Resuming all sessions one by one will in turn bump out the previous session. # Resuming all sessions one by one will in turn bump out the previous session.
for i in range(SESSIONS_STORED): for i in range(SESSIONS_STORED):
sessions[i].resume() with pytest.raises(exceptions.FailedSessionResumption) as e:
assert session_ids[i] != sessions[i].id sessions[i].resume()
assert session_ids[i] != e.value.received_session_id
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
def test_multiple_passphrases(client: Client): def test_multiple_passphrases(client: Client):
# start a session # start a session
session_a = client.get_session(passphrase="A") session_a = _get_session(client)
session_a_id = session_a.id session_a_id = session_a.id
assert _get_xpub(session_a, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] assert _get_xpub(session_a, passphrase="A") == XPUB_PASSPHRASES["A"]
# start it again wit the same session id # start it again wit the same session id
session_a.resume() session_a.resume()
# session is the same # session is the same
@ -157,10 +189,10 @@ def test_multiple_passphrases(client: Client):
assert _get_xpub(session_a) == XPUB_PASSPHRASES["A"] assert _get_xpub(session_a) == XPUB_PASSPHRASES["A"]
# start a second session # start a second session
session_b = client.get_session(passphrase="B") session_b = _get_session(client)
session_b_id = session_b.id session_b_id = session_b.id
# new session -> new session id and passphrase prompt # new session -> new session id and passphrase prompt
assert _get_xpub(session_b, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] assert _get_xpub(session_b, passphrase="B") == XPUB_PASSPHRASES["B"]
# provide the same session id -> must not ask for passphrase again. # provide the same session id -> must not ask for passphrase again.
session_b.resume() session_b.resume()
@ -188,11 +220,11 @@ def test_max_sessions_with_passphrases(client: Client):
session_ids = {} session_ids = {}
sessions = {} sessions = {}
for passphrase, xpub in XPUB_PASSPHRASES.items(): for passphrase, xpub in XPUB_PASSPHRASES.items():
session = client.get_session(passphrase=passphrase) session = _get_session(client)
assert session.id not in session_ids.values() assert session.id not in session_ids.values()
session_ids[passphrase] = session.id session_ids[passphrase] = session.id
sessions[passphrase] = session sessions[passphrase] = session
assert _get_xpub(session, expected_passphrase_req=True) == xpub assert _get_xpub(session, passphrase=passphrase) == xpub
# passphrase is not prompted for the started the sessions, regardless the order # passphrase is not prompted for the started the sessions, regardless the order
# let's try 20 different orderings # let's try 20 different orderings
@ -212,25 +244,24 @@ def test_max_sessions_with_passphrases(client: Client):
assert _get_xpub(sessions[passphrase]) == XPUB_PASSPHRASES[passphrase] assert _get_xpub(sessions[passphrase]) == XPUB_PASSPHRASES[passphrase]
# creating one more session will exceed the limit # creating one more session will exceed the limit
new_session = client.get_session(passphrase="XX") new_session = _get_session(client)
# new session asks for passphrase # new session asks for passphrase
_get_xpub(new_session, expected_passphrase_req=True) _get_xpub(new_session, passphrase="XX")
# restoring the sessions in reverse will evict the next-up session # restoring the sessions in reverse will evict the next-up session
for passphrase in reversed(passphrases): for passphrase in reversed(passphrases):
sessions[passphrase].resume() with pytest.raises(exceptions.FailedSessionResumption) as e:
_get_xpub( sessions[passphrase].resume()
sessions[passphrase], sessions[passphrase] = Session(SessionV1(client, e.value.received_session_id))
expected_passphrase_req=True, _get_xpub(sessions[passphrase], passphrase=passphrase) # passphrase is prompted
) # passphrase is prompted
def test_session_enable_passphrase(client: Client): def test_session_enable_passphrase(client: Client):
# Let's start the communication by calling Initialize. # Let's start the communication by calling Initialize.
session = client.get_session(passphrase="") session = _get_session(client)
# Trezor will not prompt for passphrase because it is turned off. # Trezor will not prompt for passphrase because it is turned off.
assert _get_xpub(session, expected_passphrase_req=False) == XPUB_PASSPHRASE_NONE assert _get_xpub(session) == XPUB_PASSPHRASE_NONE
# Turn on passphrase. # Turn on passphrase.
# Emit the call explicitly to avoid ClearSession done by the library function # Emit the call explicitly to avoid ClearSession done by the library function
@ -244,16 +275,16 @@ def test_session_enable_passphrase(client: Client):
assert _get_xpub(session) == XPUB_PASSPHRASE_NONE assert _get_xpub(session) == XPUB_PASSPHRASE_NONE
# We clear the session id now, so the passphrase should be asked. # We clear the session id now, so the passphrase should be asked.
new_session = client.get_session(passphrase="A") new_session = _get_session(client)
assert session_id != new_session.id assert session_id != new_session.id
assert _get_xpub(new_session, expected_passphrase_req=True) == XPUB_PASSPHRASES["A"] assert _get_xpub(new_session, passphrase="A") == XPUB_PASSPHRASES["A"]
@pytest.mark.models("core") @pytest.mark.models("core")
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
def test_passphrase_on_device(client: Client): def test_passphrase_on_device(client: Client):
# _init_session(client) # _init_session(client)
session = client.get_session(passphrase="A") session = _get_session(client)
# try to get xpub with passphrase on host: # try to get xpub with passphrase on host:
response = session.call_raw(XPUB_REQUEST) response = session.call_raw(XPUB_REQUEST)
assert isinstance(response, messages.PassphraseRequest) assert isinstance(response, messages.PassphraseRequest)
@ -269,7 +300,7 @@ def test_passphrase_on_device(client: Client):
assert response.xpub == XPUB_PASSPHRASES["A"] assert response.xpub == XPUB_PASSPHRASES["A"]
# make a new session # make a new session
session2 = session.client.get_session(passphrase="A") session2 = _get_session(client)
# try to get xpub with passphrase on device: # try to get xpub with passphrase on device:
response = session2.call_raw(XPUB_REQUEST) response = session2.call_raw(XPUB_REQUEST)
@ -290,10 +321,10 @@ def test_passphrase_on_device(client: Client):
@pytest.mark.models("core") @pytest.mark.models("core")
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
@pytest.mark.uninitialized_session
def test_passphrase_always_on_device(client: Client): def test_passphrase_always_on_device(client: Client):
# Let's start the communication by calling Initialize. # Let's start the communication by calling Initialize.
session = client.get_session() session = _get_session(client)
# session_id = _init_session(client)
# Force passphrase entry on Trezor. # Force passphrase entry on Trezor.
response = session.call(messages.ApplySettings(passphrase_always_on_device=True)) response = session.call(messages.ApplySettings(passphrase_always_on_device=True))
@ -314,7 +345,7 @@ def test_passphrase_always_on_device(client: Client):
assert response.xpub == XPUB_PASSPHRASE_NONE assert response.xpub == XPUB_PASSPHRASE_NONE
# In case we want to add a new passphrase we need to send session_id = None. # In case we want to add a new passphrase we need to send session_id = None.
new_session = client.get_session(passphrase="A") new_session = _get_session(client)
response = new_session.call_raw(XPUB_REQUEST) response = new_session.call_raw(XPUB_REQUEST)
assert isinstance(response, messages.ButtonRequest) assert isinstance(response, messages.ButtonRequest)
client.debug.input("A") # Input non-empty passphrase. client.debug.input("A") # Input non-empty passphrase.
@ -325,6 +356,7 @@ def test_passphrase_always_on_device(client: Client):
@pytest.mark.models("legacy") @pytest.mark.models("legacy")
@pytest.mark.setup_client(passphrase="") @pytest.mark.setup_client(passphrase="")
@pytest.mark.uninitialized_session
def test_passphrase_on_device_not_possible_on_t1(session: Session): def test_passphrase_on_device_not_possible_on_t1(session: Session):
# This setting makes no sense on T1. # This setting makes no sense on T1.
response = session.call_raw( response = session.call_raw(
@ -342,6 +374,7 @@ def test_passphrase_on_device_not_possible_on_t1(session: Session):
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
@pytest.mark.uninitialized_session
def test_passphrase_ack_mismatch(session: Session): def test_passphrase_ack_mismatch(session: Session):
response = session.call_raw(XPUB_REQUEST) response = session.call_raw(XPUB_REQUEST)
assert isinstance(response, messages.PassphraseRequest) assert isinstance(response, messages.PassphraseRequest)
@ -350,7 +383,8 @@ def test_passphrase_ack_mismatch(session: Session):
assert response.code == FailureType.DataError assert response.code == FailureType.DataError
@pytest.mark.setup_client(passphrase="") @pytest.mark.setup_client(passphrase=True)
@pytest.mark.uninitialized_session
def test_passphrase_missing(session: Session): def test_passphrase_missing(session: Session):
response = session.call_raw(XPUB_REQUEST) response = session.call_raw(XPUB_REQUEST)
assert isinstance(response, messages.PassphraseRequest) assert isinstance(response, messages.PassphraseRequest)
@ -368,9 +402,10 @@ def test_passphrase_missing(session: Session):
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
@pytest.mark.uninitialized_session
def test_passphrase_length(client: Client): def test_passphrase_length(client: Client):
def call(passphrase: str, expected_result: bool): def call(passphrase: str, expected_result: bool):
session = client.get_session(passphrase=passphrase) session = _get_session(client)
response = session.call_raw(XPUB_REQUEST) response = session.call_raw(XPUB_REQUEST)
assert isinstance(response, messages.PassphraseRequest) assert isinstance(response, messages.PassphraseRequest)
try: try:
@ -405,7 +440,7 @@ def test_hide_passphrase_from_host(client: Client):
device.apply_settings(session, hide_passphrase_from_host=True) device.apply_settings(session, hide_passphrase_from_host=True)
passphrase = "abc" passphrase = "abc"
session = client.get_session(passphrase=passphrase) session = _get_session(client)
with session: with session:
def input_flow(): def input_flow():
@ -430,15 +465,17 @@ def test_hide_passphrase_from_host(client: Client):
messages.PublicKey, messages.PublicKey,
] ]
) )
result = session.call(XPUB_REQUEST) resp = session.call_raw(XPUB_REQUEST)
assert isinstance(result, messages.PublicKey) resp = session.call_raw(messages.PassphraseAck(passphrase=passphrase))
xpub_hidden_passphrase = result.xpub resp = session._callback_button(resp)
assert isinstance(resp, messages.PublicKey)
xpub_hidden_passphrase = resp.xpub
# Turning it off # Turning it off
device.apply_settings(session, hide_passphrase_from_host=False) device.apply_settings(session, hide_passphrase_from_host=False)
# Starting new session, otherwise the passphrase would be cached # Starting new session, otherwise the passphrase would be cached
session = client.get_session(passphrase=passphrase) session = _get_session(client)
with session: with session:
@ -465,22 +502,29 @@ def test_hide_passphrase_from_host(client: Client):
messages.PublicKey, messages.PublicKey,
] ]
) )
result = session.call(XPUB_REQUEST) resp = session.call_raw(XPUB_REQUEST)
assert isinstance(result, messages.PublicKey) assert isinstance(resp, messages.PassphraseRequest)
xpub_shown_passphrase = result.xpub resp = session.call_raw(messages.PassphraseAck(passphrase=passphrase))
resp = session._callback_button(resp)
resp = session._callback_button(resp)
assert isinstance(resp, messages.PublicKey)
xpub_shown_passphrase = resp.xpub
assert xpub_hidden_passphrase == xpub_shown_passphrase assert xpub_hidden_passphrase == xpub_shown_passphrase
def _get_xpub_cardano(session: Session, expected_passphrase_req: bool = False): def _get_xpub_cardano(
session: Session,
passphrase: str | None = None,
):
msg = messages.CardanoGetPublicKey( msg = messages.CardanoGetPublicKey(
address_n=parse_path("m/44h/1815h/0h/0/0"), address_n=parse_path("m/44h/1815h/0h/0/0"),
derivation_type=messages.CardanoDerivationType.ICARUS, derivation_type=messages.CardanoDerivationType.ICARUS,
) )
response = session.call_raw(msg) response = session.call_raw(msg)
if expected_passphrase_req: if passphrase is not None:
assert isinstance(response, messages.PassphraseRequest) assert isinstance(response, messages.PassphraseRequest)
response = session.call(messages.PassphraseAck(passphrase=session.passphrase)) response = session.call(messages.PassphraseAck(passphrase=passphrase))
assert isinstance(response, messages.CardanoPublicKey) assert isinstance(response, messages.CardanoPublicKey)
return response.xpub return response.xpub
@ -497,8 +541,8 @@ def test_cardano_passphrase(client: Client):
# GetPublicKey requires passphrase and since it is not cached, # GetPublicKey requires passphrase and since it is not cached,
# Trezor will prompt for it. # Trezor will prompt for it.
session = client.get_session(passphrase="B", derive_cardano=True) session = _get_session(client, derive_cardano=True)
assert _get_xpub(session, expected_passphrase_req=True) == XPUB_PASSPHRASES["B"] assert _get_xpub(session, passphrase="B") == XPUB_PASSPHRASES["B"]
# The passphrase is now cached for non-Cardano coins. # The passphrase is now cached for non-Cardano coins.
assert _get_xpub(session) == XPUB_PASSPHRASES["B"] assert _get_xpub(session) == XPUB_PASSPHRASES["B"]
@ -513,14 +557,10 @@ def test_cardano_passphrase(client: Client):
assert _get_xpub_cardano(session) == XPUB_CARDANO_PASSPHRASE_B assert _get_xpub_cardano(session) == XPUB_CARDANO_PASSPHRASE_B
# New session will destroy the state # New session will destroy the state
new_session = client.get_session(passphrase="A", derive_cardano=True) new_session = _get_session(client, derive_cardano=True)
# _init_session(client, derive_cardano=True)
# Cardano must ask for passphrase again # Cardano must ask for passphrase again
assert ( assert _get_xpub_cardano(new_session, passphrase="A") == XPUB_CARDANO_PASSPHRASE_A
_get_xpub_cardano(new_session, expected_passphrase_req=True)
== XPUB_CARDANO_PASSPHRASE_A
)
# Passphrase is now cached for Cardano # Passphrase is now cached for Cardano
assert _get_xpub_cardano(new_session) == XPUB_CARDANO_PASSPHRASE_A assert _get_xpub_cardano(new_session) == XPUB_CARDANO_PASSPHRASE_A