diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index 65e7d3bf94..97cbe9b997 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -53,7 +53,7 @@ def _pin_request(session: Session): def _assert_protection( session: Session, pin: bool = True, passphrase: bool = True -) -> None: +) -> Session: """Make sure PIN and passphrase protection have expected values""" with session, session.client as client: client.use_pin_sequence([PIN4]) @@ -61,8 +61,13 @@ def _assert_protection( client.refresh_features() assert client.features.pin_protection is pin assert client.features.passphrase_protection is passphrase + if session.session_version == Session.THP_V2: + new_session = session.client.get_session() session.lock() - # TODO session.clear_session() + session.end() + if session.session_version == Session.CODEC_V1: + new_session = session.client.get_session() + return Session(new_session) def test_initialize(session: Session): @@ -73,7 +78,7 @@ def test_initialize(session: Session): with session, session.client as client: client.use_pin_sequence([PIN4]) session.ensure_unlocked() - _assert_protection(session) + session = _assert_protection(session) with session: session.set_expected_responses([messages.Features]) session.call(messages.Initialize(session_id=session.id)) @@ -97,7 +102,7 @@ def test_passphrase_reporting(session: Session, passphrase): assert session.features.passphrase_protection is None # on an unlocked device, protection should be reported accurately - _assert_protection(session, pin=True, passphrase=passphrase) + session = _assert_protection(session, pin=True, passphrase=passphrase) # after re-locking, the setting should be hidden again session.lock() @@ -106,7 +111,7 @@ def test_passphrase_reporting(session: Session, passphrase): def test_apply_settings(session: Session): - _assert_protection(session) + session = _assert_protection(session) with session, session.client as client: client.use_pin_sequence([PIN4]) @@ -123,7 +128,7 @@ def test_apply_settings(session: Session): @pytest.mark.models("legacy") def test_change_pin_t1(session: Session): - _assert_protection(session) + session = _assert_protection(session) with session, session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4]) session.set_expected_responses( @@ -141,7 +146,7 @@ def test_change_pin_t1(session: Session): @pytest.mark.models("core") def test_change_pin_t2(session: Session): - _assert_protection(session) + session = _assert_protection(session) with session, session.client as client: client.use_pin_sequence([PIN4, PIN4, PIN4, PIN4]) session.set_expected_responses( @@ -162,14 +167,14 @@ def test_change_pin_t2(session: Session): @pytest.mark.setup_client(pin=None, passphrase=False) def test_ping(session: Session): - _assert_protection(session, pin=False, passphrase=False) + session = _assert_protection(session, pin=False, passphrase=False) with session: session.set_expected_responses([messages.ButtonRequest, messages.Success]) session.call(messages.Ping(message="msg", button_protection=True)) def test_get_entropy(session: Session): - _assert_protection(session) + session = _assert_protection(session) with session, session.client as client: client.use_pin_sequence([PIN4]) session.set_expected_responses( @@ -183,11 +188,12 @@ def test_get_entropy(session: Session): def test_get_public_key(session: Session): - _assert_protection(session) + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - expected_responses = [_pin_request(session)] + if session.session_version == Session.CODEC_V1: expected_responses.append(messages.PassphraseRequest) expected_responses.append(messages.PublicKey) @@ -197,10 +203,10 @@ def test_get_public_key(session: Session): def test_get_address(session: Session): - _assert_protection(session) + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) - expected_responses = [_pin_request(session)] if session.session_version == Session.CODEC_V1: expected_responses.append(messages.PassphraseRequest) @@ -228,7 +234,7 @@ def test_wipe_device(session: Session): # File "storage/common.py", line 21, in set # RuntimeError: Could not save value - _assert_protection(session) + session = _assert_protection(session) with session: session.set_expected_responses([messages.ButtonRequest, messages.Success]) device.wipe(session) @@ -310,7 +316,8 @@ def test_recovery_device(session: Session): def test_sign_message(session: Session): - _assert_protection(session) + session = _assert_protection(session) + with session, session.client as client: client.use_pin_sequence([PIN4]) @@ -336,7 +343,7 @@ def test_sign_message(session: Session): @pytest.mark.models("legacy") def test_verify_message_t1(session: Session): - _assert_protection(session) + session = _assert_protection(session) with session: session.set_expected_responses( [ @@ -359,7 +366,7 @@ def test_verify_message_t1(session: Session): @pytest.mark.models("core") def test_verify_message_t2(session: Session): - _assert_protection(session) + session = _assert_protection(session) with session, session.client as client: client.use_pin_sequence([PIN4]) session.set_expected_responses( @@ -398,7 +405,7 @@ def test_signtx(session: Session): script_type=messages.OutputScriptType.PAYTOADDRESS, ) - _assert_protection(session) + session = _assert_protection(session) with session, session.client as client: client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] @@ -439,7 +446,8 @@ def test_unlocked(session: Session): session.lock() assert session.features.unlocked is False - _assert_protection(session, passphrase=False) + session = _assert_protection(session, passphrase=False) + with session, session.client as client: client.use_pin_sequence([PIN4]) session.set_expected_responses([_pin_request(session), messages.Address]) @@ -454,7 +462,7 @@ def test_unlocked(session: Session): @pytest.mark.setup_client(pin=None, passphrase=True) def test_passphrase_cached(session: Session): - _assert_protection(session, pin=False) + session = _assert_protection(session, pin=False) with session: if session.session_version == 1: session.set_expected_responses(