1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-07 16:49:04 +00:00

fix(tests): fix device tests for protocol_v2

This commit is contained in:
M1nd3r 2025-04-15 13:43:55 +02:00
parent c02e34e3d9
commit 524c0c80bf
2 changed files with 29 additions and 17 deletions

View File

@ -19,7 +19,6 @@ import time
import pytest import pytest
from trezorlib import btc, device, messages from trezorlib import btc, device, messages
from trezorlib.debuglink import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
@ -808,11 +807,7 @@ def test_multisession_authorization(client: Client):
) )
# Open a second session. # Open a second session.
if client.protocol_version is ProtocolVersion.V2: session2 = client.get_session()
session_id = b"\x02"
else:
session_id = None
session2 = client.get_session(session_id=session_id)
# Authorize CoinJoin with www.example2.com in session 2. # Authorize CoinJoin with www.example2.com in session 2.
btc.authorize_coinjoin( btc.authorize_coinjoin(

View File

@ -57,7 +57,10 @@ def _assert_protection(client: Client, pin: bool = True, passphrase: bool = True
"""Make sure PIN and passphrase protection have expected values""" """Make sure PIN and passphrase protection have expected values"""
with client: with client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
if client.protocol_version is ProtocolVersion.V1:
session = client.get_seedless_session() session = client.get_seedless_session()
else:
session = client.get_session()
try: try:
session.ensure_unlocked() session.ensure_unlocked()
except exceptions.InvalidSessionError: except exceptions.InvalidSessionError:
@ -119,10 +122,11 @@ def test_passphrase_reporting(session: Session, passphrase):
def test_apply_settings(client: Client): def test_apply_settings(client: Client):
_assert_protection(client) _assert_protection(client)
with client: with client:
v1 = client.protocol_version == ProtocolVersion.V1
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
client.set_expected_responses( client.set_expected_responses(
[ [
messages.Features, (v1, messages.Features),
_pin_request(client), _pin_request(client),
messages.ButtonRequest, messages.ButtonRequest,
messages.Success, messages.Success,
@ -204,11 +208,15 @@ def test_get_public_key(client: Client):
_assert_protection(client) _assert_protection(client)
with client: with client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
expected_responses = [messages.Features, _pin_request(client)] v1 = client.protocol_version == ProtocolVersion.V1
expected_responses = [
if client.protocol_version == ProtocolVersion.V1: (v1, messages.Features),
expected_responses.append(messages.PassphraseRequest) _pin_request(client),
expected_responses.extend([messages.Address, messages.PublicKey]) (v1, messages.PassphraseRequest),
(not v1, messages.Success),
(v1, messages.Address),
messages.PublicKey,
]
client.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
session = client.get_session() session = client.get_session()
@ -220,11 +228,16 @@ def test_get_address(client: Client):
_assert_protection(client) _assert_protection(client)
with client: with client:
v1 = client.protocol_version == ProtocolVersion.V1
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
expected_responses = [messages.Features, _pin_request(client)] expected_responses = [
if client.protocol_version == ProtocolVersion.V1: (v1, messages.Features),
expected_responses.extend([messages.PassphraseRequest, messages.Address]) _pin_request(client),
expected_responses.append(messages.Address) (v1, messages.PassphraseRequest),
(v1, messages.Address),
(not v1, messages.Success),
messages.Address,
]
client.set_expected_responses(expected_responses) client.set_expected_responses(expected_responses)
session = client.get_session() session = client.get_session()
@ -330,6 +343,7 @@ def test_sign_message(client: Client):
_pin_request(client), _pin_request(client),
(v1, messages.PassphraseRequest), (v1, messages.PassphraseRequest),
(v1, messages.Address), (v1, messages.Address),
(not v1, messages.Success),
messages.ButtonRequest, messages.ButtonRequest,
messages.ButtonRequest, messages.ButtonRequest,
messages.MessageSignature, messages.MessageSignature,
@ -389,6 +403,7 @@ def test_verify_message_t2(client: Client):
[ [
(v1, messages.Features), (v1, messages.Features),
_pin_request(client), _pin_request(client),
(not v1, messages.Success),
(v1, messages.PassphraseRequest), (v1, messages.PassphraseRequest),
(v1, messages.Address), (v1, messages.Address),
messages.ButtonRequest, messages.ButtonRequest,
@ -434,6 +449,7 @@ def test_signtx(client: Client):
expected_responses = [ expected_responses = [
(v1, messages.Features), (v1, messages.Features),
_pin_request(client), _pin_request(client),
(not v1, messages.Success),
(v1, messages.PassphraseRequest), (v1, messages.PassphraseRequest),
(v1, messages.Address), (v1, messages.Address),
request_input(0), request_input(0),
@ -475,6 +491,7 @@ def test_unlocked(client: Client):
[ [
(v1, messages.Features), (v1, messages.Features),
_pin_request(client), _pin_request(client),
(not v1, messages.Success),
(v1, messages.Address), (v1, messages.Address),
messages.Address, messages.Address,
] ]