1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-23 05:40:57 +00:00

fix(python): fix device tests for test_session

[no changelog]
This commit is contained in:
M1nd3r 2024-11-25 10:40:05 +01:00
parent ddc204a4e1
commit b1821b733f
3 changed files with 116 additions and 96 deletions

View File

@ -124,13 +124,22 @@ class TrezorClient:
from .transport.session import SessionV1, SessionV2 from .transport.session import SessionV1, SessionV2
if isinstance(self.protocol, ProtocolV1): if isinstance(self.protocol, ProtocolV1):
if passphrase is None:
passphrase = ""
return SessionV1.new(self, passphrase, derive_cardano) return SessionV1.new(self, passphrase, derive_cardano)
if isinstance(self.protocol, ProtocolV2): if isinstance(self.protocol, ProtocolV2):
return SessionV2.new(self, passphrase, derive_cardano) return SessionV2.new(self, passphrase, derive_cardano)
raise NotImplementedError # TODO raise NotImplementedError # TODO
def resume_session(self, session: Session): def resume_session(self, session: Session):
"""
Note: this function potentially modifies the input session.
"""
from trezorlib.transport.session import SessionV1, SessionV2 from trezorlib.transport.session import SessionV1, SessionV2
from trezorlib.debuglink import SessionDebugWrapper
if isinstance(session, SessionDebugWrapper):
session = session._session
if isinstance(session, SessionV2): if isinstance(session, SessionV2):
return session return session

View File

@ -73,8 +73,8 @@ class Session:
def refresh_features(self) -> None: def refresh_features(self) -> None:
self.client.refresh_features() self.client.refresh_features()
def end(self) -> None: def end(self) -> t.Any:
raise NotImplementedError return self.call(messages.EndSession())
@property @property
def features(self) -> messages.Features: def features(self) -> messages.Features:

View File

@ -19,8 +19,10 @@ import pytest
from trezorlib import cardano, messages, models from trezorlib import cardano, messages, models
from trezorlib.btc import get_public_node from trezorlib.btc import get_public_node
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from trezorlib.transport.session import SessionV1
from ..common import get_test_address from ..common import get_test_address
@ -32,6 +34,10 @@ PIN4 = "1234"
def test_thp_end_session(client: Client): def test_thp_end_session(client: Client):
session = client.get_session() session = client.get_session()
if isinstance(session, SessionV1):
# TODO: This test should be skipped on non-THP builds
return
msg = session.call(messages.EndSession()) msg = session.call(messages.EndSession())
assert isinstance(msg, messages.Success) assert isinstance(msg, messages.Success)
with pytest.raises(TrezorFailure, match="ThpUnallocatedSession"): with pytest.raises(TrezorFailure, match="ThpUnallocatedSession"):
@ -47,100 +53,105 @@ def test_clear_session(client: Client):
] ]
cached_responses = [messages.PublicKey] cached_responses = [messages.PublicKey]
session = Session(client.get_session())
with client: session.lock()
with client, session:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
client.set_expected_responses(init_responses + cached_responses) session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(client, ADDRESS_N).xpub == XPUB assert get_public_node(session, ADDRESS_N).xpub == XPUB
with client: client.resume_session(session)
with session:
# pin and passphrase are cached # pin and passphrase are cached
client.set_expected_responses(cached_responses) session.set_expected_responses(cached_responses)
assert get_public_node(client, ADDRESS_N).xpub == XPUB assert get_public_node(session, ADDRESS_N).xpub == XPUB
client.clear_session() session.lock()
session.end()
session = Session(client.get_session())
# session cache is cleared # session cache is cleared
with client: with client, session:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
client.set_expected_responses(init_responses + cached_responses) session.set_expected_responses(init_responses + cached_responses)
assert get_public_node(client, ADDRESS_N).xpub == XPUB assert get_public_node(session, ADDRESS_N).xpub == XPUB
with client: client.resume_session(session)
with session:
# pin and passphrase are cached # pin and passphrase are cached
client.set_expected_responses(cached_responses) session.set_expected_responses(cached_responses)
assert get_public_node(client, ADDRESS_N).xpub == XPUB assert get_public_node(session, ADDRESS_N).xpub == XPUB
def test_end_session(client: Client): def test_end_session(client: Client):
# client instance starts out not initialized # client instance starts out not initialized
# XXX do we want to change this? # XXX do we want to change this?
assert client.session_id is not None session = client.get_session()
assert session.id is not None
# get_address will succeed # get_address will succeed
with client: with Session(session) as session:
client.set_expected_responses([messages.Address]) session.set_expected_responses([messages.Address])
get_test_address(client) get_test_address(session)
client.end_session() session.end()
assert client.session_id is None # assert client.session_id is None
with pytest.raises(TrezorFailure) as exc: with pytest.raises(TrezorFailure) as exc:
get_test_address(client) get_test_address(session)
assert exc.value.code == messages.FailureType.InvalidSession assert exc.value.code == messages.FailureType.InvalidSession
assert exc.value.message.endswith("Invalid session") assert exc.value.message.endswith("Invalid session")
client.init_device() session = client.get_session()
assert client.session_id is not None assert session.id is not None
with client: with Session(session) as session:
client.set_expected_responses([messages.Address]) session.set_expected_responses([messages.Address])
get_test_address(client) get_test_address(session)
with client: # TODO: is the following valid? I do not think so
# end_session should succeed on empty session too # with Session(session) as session:
client.set_expected_responses([messages.Success] * 2) # # end_session should succeed on empty session too
client.end_session() # session.set_expected_responses([messages.Success] * 2)
client.end_session() # session.end_session()
# session.end_session()
def test_cannot_resume_ended_session(client: Client): def test_cannot_resume_ended_session(client: Client):
session_id = client.session_id session = client.get_session()
with client: session_id = session.id
client.set_expected_responses([messages.Features])
client.init_device(session_id=session_id)
assert session_id == client.session_id session_resumed = client.resume_session(session)
client.end_session() assert session_resumed.id == session_id
with client:
client.set_expected_responses([messages.Features])
client.init_device(session_id=session_id)
assert session_id != client.session_id session.end()
session_resumed2 = client.resume_session(session)
assert session_resumed2.id != session_id
def test_end_session_only_current(client: Client): def test_end_session_only_current(client: Client):
"""test that EndSession only destroys the current session""" """test that EndSession only destroys the current session"""
session_id_a = client.session_id session_a = client.get_session()
client.init_device(new_session=True) session_b = client.get_session()
session_id_b = client.session_id session_b_id = session_b.id
client.end_session() session_b.end()
assert client.session_id is None # assert client.session_id is None
# resume ended session # resume ended session
client.init_device(session_id=session_id_b) session_b_resumed = client.resume_session(session_b)
assert client.session_id != session_id_b assert session_b_resumed.id != session_b_id
# resume first session that was not ended # resume first session that was not ended
client.init_device(session_id=session_id_a) session_a_resumed = client.resume_session(session_a)
assert client.session_id == session_id_a assert session_a_resumed.id == session_a.id
@pytest.mark.setup_client(passphrase=True) @pytest.mark.setup_client(passphrase=True)
def test_session_recycling(client: Client): def test_session_recycling(client: Client):
session_id_orig = client.session_id session = Session(client.get_session(passphrase="TREZOR"))
with client: with client, session:
client.set_expected_responses( session.set_expected_responses(
[ [
messages.PassphraseRequest, messages.PassphraseRequest,
messages.ButtonRequest, messages.ButtonRequest,
@ -149,20 +160,21 @@ def test_session_recycling(client: Client):
] ]
) )
client.use_passphrase("TREZOR") client.use_passphrase("TREZOR")
address = get_test_address(client) address = get_test_address(session)
# create and close 100 sessions - more than the session limit # create and close 100 sessions - more than the session limit
for _ in range(100): for _ in range(100):
client.init_device(new_session=True) session_x = client.get_session()
client.end_session() session_x.end()
# it should still be possible to resume the original session # it should still be possible to resume the original session
with client: # TODO imo not True anymore
# passphrase should still be cached # with client, session:
client.set_expected_responses([messages.Features, messages.Address]) # # passphrase should still be cached
client.use_passphrase("TREZOR") # session.set_expected_responses([messages.Features, messages.Address])
client.init_device(session_id=session_id_orig) # client.use_passphrase("TREZOR")
assert address == get_test_address(client) # client.resume_session(session)
# assert address == get_test_address(session)
@pytest.mark.altcoin @pytest.mark.altcoin
@ -170,18 +182,19 @@ def test_session_recycling(client: Client):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_derive_cardano_empty_session(client: Client): def test_derive_cardano_empty_session(client: Client):
# start new session # start new session
client.init_device(new_session=True) session = client.get_session(derive_cardano=True)
session_id = client.session_id # session_id = client.session_id
# restarting same session should go well # restarting same session should go well
client.init_device() session2 = client.resume_session(session)
assert session_id == client.session_id assert session.id == session2.id
# restarting same session should go well with any setting # restarting same session should go well with any setting
client.init_device(derive_cardano=False) # TODO I do not think that it holds True now
assert session_id == client.session_id # client.init_device(derive_cardano=False)
client.init_device(derive_cardano=True) # assert session_id == client.session_id
assert session_id == client.session_id # client.init_device(derive_cardano=True)
# assert session_id == client.session_id
@pytest.mark.altcoin @pytest.mark.altcoin
@ -189,43 +202,41 @@ def test_derive_cardano_empty_session(client: Client):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_derive_cardano_running_session(client: Client): def test_derive_cardano_running_session(client: Client):
# start new session # start new session
client.init_device(new_session=True) session = client.get_session(derive_cardano=False)
session_id = client.session_id
# force derivation of seed # force derivation of seed
get_test_address(client) get_test_address(session)
# session should not have Cardano capability # session should not have Cardano capability
with pytest.raises(TrezorFailure, match="not enabled"): with pytest.raises(TrezorFailure, match="not enabled"):
cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) cardano.get_public_key(session, parse_path("m/44h/1815h/0h"))
# restarting same session should go well # restarting same session should go well
client.init_device() session2 = client.resume_session(session)
assert session_id == client.session_id assert session.id == session2.id
# restarting same session should go well if we _don't_ want to derive cardano # TODO restarting same session should go well if we _don't_ want to derive cardano
client.init_device(derive_cardano=False) # # client.init_device(derive_cardano=False)
assert session_id == client.session_id # # assert session_id == client.session_id
# restarting with derive_cardano=True should kill old session and create new one # restarting with derive_cardano=True should kill old session and create new one
client.init_device(derive_cardano=True) session3 = client.get_session(derive_cardano=True)
assert session_id != client.session_id assert session3.id != session.id
session_id = client.session_id
# new session should have Cardano capability # new session should have Cardano capability
cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) cardano.get_public_key(session3, parse_path("m/44h/1815h/0h"))
# restarting with derive_cardano=True should keep same session # restarting with derive_cardano=True should keep same session
client.init_device(derive_cardano=True) session4 = client.resume_session(session3)
assert session_id == client.session_id assert session4.id == session3.id
# restarting with no setting should keep same session # # restarting with no setting should keep same session
client.init_device() # client.init_device()
assert session_id == client.session_id # assert session_id == client.session_id
# restarting with derive_cardano=False should kill old session and create new one # # restarting with derive_cardano=False should kill old session and create new one
client.init_device(derive_cardano=False) # client.init_device(derive_cardano=False)
assert session_id != client.session_id # assert session_id != client.session_id
with pytest.raises(TrezorFailure, match="not enabled"): # with pytest.raises(TrezorFailure, match="not enabled"):
cardano.get_public_key(client, parse_path("m/44h/1815h/0h")) # cardano.get_public_key(client, parse_path("m/44h/1815h/0h"))