diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index 2b0ab45d03..1ea1cb0a56 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -19,7 +19,6 @@ import time import pytest from trezorlib import btc, device, messages -from trezorlib.debuglink import ProtocolVersion from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure @@ -810,11 +809,7 @@ def test_multisession_authorization(client: Client): ) # Open a second session. - if client.protocol_version is ProtocolVersion.V2: - session_id = b"\x02" - else: - session_id = None - session2 = client.get_session(session_id=session_id) + session2 = client.get_session() # Authorize CoinJoin with www.example2.com in session 2. btc.authorize_coinjoin( diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index e715eeabc6..8c7ec805a1 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -57,7 +57,10 @@ def _assert_protection(client: Client, pin: bool = True, passphrase: bool = True """Make sure PIN and passphrase protection have expected values""" with client: client.use_pin_sequence([PIN4]) - session = client.get_seedless_session() + if client.protocol_version is ProtocolVersion.V1: + session = client.get_seedless_session() + else: + session = client.get_session() try: session.ensure_unlocked() except exceptions.InvalidSessionError: @@ -119,10 +122,11 @@ def test_passphrase_reporting(session: Session, passphrase): def test_apply_settings(client: Client): _assert_protection(client) with client: + v1 = client.protocol_version == ProtocolVersion.V1 client.use_pin_sequence([PIN4]) client.set_expected_responses( [ - messages.Features, + (v1, messages.Features), _pin_request(client), messages.ButtonRequest, messages.Success, @@ -204,11 +208,15 @@ def test_get_public_key(client: Client): _assert_protection(client) with client: client.use_pin_sequence([PIN4]) - expected_responses = [messages.Features, _pin_request(client)] - - if client.protocol_version == ProtocolVersion.V1: - expected_responses.append(messages.PassphraseRequest) - expected_responses.extend([messages.Address, messages.PublicKey]) + v1 = client.protocol_version == ProtocolVersion.V1 + expected_responses = [ + (v1, messages.Features), + _pin_request(client), + (v1, messages.PassphraseRequest), + (not v1, messages.Success), + (v1, messages.Address), + messages.PublicKey, + ] client.set_expected_responses(expected_responses) session = client.get_session() @@ -220,11 +228,16 @@ def test_get_address(client: Client): _assert_protection(client) with client: + v1 = client.protocol_version == ProtocolVersion.V1 client.use_pin_sequence([PIN4]) - expected_responses = [messages.Features, _pin_request(client)] - if client.protocol_version == ProtocolVersion.V1: - expected_responses.extend([messages.PassphraseRequest, messages.Address]) - expected_responses.append(messages.Address) + expected_responses = [ + (v1, messages.Features), + _pin_request(client), + (v1, messages.PassphraseRequest), + (v1, messages.Address), + (not v1, messages.Success), + messages.Address, + ] client.set_expected_responses(expected_responses) session = client.get_session() @@ -331,6 +344,7 @@ def test_sign_message(client: Client): _pin_request(client), (v1, messages.PassphraseRequest), (v1, messages.Address), + (not v1, messages.Success), messages.ButtonRequest, messages.ButtonRequest, messages.MessageSignature, @@ -390,6 +404,7 @@ def test_verify_message_t2(client: Client): [ (v1, messages.Features), _pin_request(client), + (not v1, messages.Success), (v1, messages.PassphraseRequest), (v1, messages.Address), messages.ButtonRequest, @@ -435,6 +450,7 @@ def test_signtx(client: Client): expected_responses = [ (v1, messages.Features), _pin_request(client), + (not v1, messages.Success), (v1, messages.PassphraseRequest), (v1, messages.Address), request_input(0), @@ -476,6 +492,7 @@ def test_unlocked(client: Client): [ (v1, messages.Features), _pin_request(client), + (not v1, messages.Success), (v1, messages.Address), messages.Address, ]