1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-02 02:41:28 +00:00

tests(python): improve thp device tests

[no changelog]
This commit is contained in:
M1nd3r 2025-01-29 21:41:36 +01:00
parent b3cb270249
commit 41abafc288
5 changed files with 120 additions and 171 deletions

View File

@ -103,7 +103,7 @@ message ThpPairingRequestApproved{
* @next ThpCodeEntryCommitment * @next ThpCodeEntryCommitment
*/ */
message ThpSelectMethod { message ThpSelectMethod {
optional ThpPairingMethod selected_pairing_method = 1 [default=NFC];; optional ThpPairingMethod selected_pairing_method = 1;
} }
/** /**

View File

@ -6247,7 +6247,7 @@ if TYPE_CHECKING:
return isinstance(msg, cls) return isinstance(msg, cls)
class ThpSelectMethod(protobuf.MessageType): class ThpSelectMethod(protobuf.MessageType):
selected_pairing_method: "ThpPairingMethod" selected_pairing_method: "ThpPairingMethod | None"
def __init__( def __init__(
self, self,

View File

@ -7985,13 +7985,13 @@ class ThpPairingRequestApproved(protobuf.MessageType):
class ThpSelectMethod(protobuf.MessageType): class ThpSelectMethod(protobuf.MessageType):
MESSAGE_WIRE_TYPE = 1008 MESSAGE_WIRE_TYPE = 1008
FIELDS = { FIELDS = {
1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=ThpPairingMethod.NFC), 1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=None),
} }
def __init__( def __init__(
self, self,
*, *,
selected_pairing_method: Optional["ThpPairingMethod"] = ThpPairingMethod.NFC, selected_pairing_method: Optional["ThpPairingMethod"] = None,
) -> None: ) -> None:
self.selected_pairing_method = selected_pairing_method self.selected_pairing_method = selected_pairing_method

View File

@ -992,8 +992,8 @@ impl ThpSelectMethod {
pub fn selected_pairing_method(&self) -> ThpPairingMethod { pub fn selected_pairing_method(&self) -> ThpPairingMethod {
match self.selected_pairing_method { match self.selected_pairing_method {
Some(e) => e.enum_value_or(ThpPairingMethod::NFC), Some(e) => e.enum_value_or(ThpPairingMethod::SkipPairing),
None => ThpPairingMethod::NFC, None => ThpPairingMethod::SkipPairing,
} }
} }
@ -4105,24 +4105,24 @@ static file_descriptor_proto_data: &'static [u8] = b"\
on_device\x18\x02\x20\x01(\x08R\x08onDevice\x12%\n\x0ederive_cardano\x18\ on_device\x18\x02\x20\x01(\x08R\x08onDevice\x12%\n\x0ederive_cardano\x18\
\x03\x20\x01(\x08R\rderiveCardano\"0\n\x11ThpPairingRequest\x12\x1b\n\th\ \x03\x20\x01(\x08R\rderiveCardano\"0\n\x11ThpPairingRequest\x12\x1b\n\th\
ost_name\x18\x01\x20\x01(\tR\x08hostName\"\x1b\n\x19ThpPairingRequestApp\ ost_name\x18\x01\x20\x01(\tR\x08hostName\"\x1b\n\x19ThpPairingRequestApp\
roved\"x\n\x0fThpSelectMethod\x12e\n\x17selected_pairing_method\x18\x01\ roved\"s\n\x0fThpSelectMethod\x12`\n\x17selected_pairing_method\x18\x01\
\x20\x01(\x0e2(.hw.trezor.messages.thp.ThpPairingMethod:\x03NFCR\x15sele\ \x20\x01(\x0e2(.hw.trezor.messages.thp.ThpPairingMethodR\x15selectedPair\
ctedPairingMethod\"\x20\n\x1eThpPairingPreparationsFinished\"8\n\x16ThpC\ ingMethod\"\x20\n\x1eThpPairingPreparationsFinished\"8\n\x16ThpCodeEntry\
odeEntryCommitment\x12\x1e\n\ncommitment\x18\x01\x20\x01(\x0cR\ncommitme\ Commitment\x12\x1e\n\ncommitment\x18\x01\x20\x01(\x0cR\ncommitment\"5\n\
nt\"5\n\x15ThpCodeEntryChallenge\x12\x1c\n\tchallenge\x18\x01\x20\x01(\ \x15ThpCodeEntryChallenge\x12\x1c\n\tchallenge\x18\x01\x20\x01(\x0cR\tch\
\x0cR\tchallenge\"P\n\x17ThpCodeEntryCpaceTrezor\x125\n\x17cpace_trezor_\ allenge\"P\n\x17ThpCodeEntryCpaceTrezor\x125\n\x17cpace_trezor_public_ke\
public_key\x18\x01\x20\x01(\x0cR\x14cpaceTrezorPublicKey\"_\n\x18ThpCode\ y\x18\x01\x20\x01(\x0cR\x14cpaceTrezorPublicKey\"_\n\x18ThpCodeEntryCpac\
EntryCpaceHostTag\x121\n\x15cpace_host_public_key\x18\x01\x20\x01(\x0cR\ eHostTag\x121\n\x15cpace_host_public_key\x18\x01\x20\x01(\x0cR\x12cpaceH\
\x12cpaceHostPublicKey\x12\x10\n\x03tag\x18\x02\x20\x01(\x0cR\x03tag\",\ ostPublicKey\x12\x10\n\x03tag\x18\x02\x20\x01(\x0cR\x03tag\",\n\x12ThpCo\
\n\x12ThpCodeEntrySecret\x12\x16\n\x06secret\x18\x01\x20\x01(\x0cR\x06se\ deEntrySecret\x12\x16\n\x06secret\x18\x01\x20\x01(\x0cR\x06secret\"\x20\
cret\"\x20\n\x0cThpQrCodeTag\x12\x10\n\x03tag\x18\x01\x20\x01(\x0cR\x03t\ \n\x0cThpQrCodeTag\x12\x10\n\x03tag\x18\x01\x20\x01(\x0cR\x03tag\")\n\
ag\")\n\x0fThpQrCodeSecret\x12\x16\n\x06secret\x18\x01\x20\x01(\x0cR\x06\ \x0fThpQrCodeSecret\x12\x16\n\x06secret\x18\x01\x20\x01(\x0cR\x06secret\
secret\"!\n\rThpNfcTagHost\x12\x10\n\x03tag\x18\x01\x20\x01(\x0cR\x03tag\ \"!\n\rThpNfcTagHost\x12\x10\n\x03tag\x18\x01\x20\x01(\x0cR\x03tag\"#\n\
\"#\n\x0fThpNfcTagTrezor\x12\x10\n\x03tag\x18\x01\x20\x01(\x0cR\x03tag\"\ \x0fThpNfcTagTrezor\x12\x10\n\x03tag\x18\x01\x20\x01(\x0cR\x03tag\"f\n\
f\n\x14ThpCredentialRequest\x12,\n\x12host_static_pubkey\x18\x01\x20\x01\ \x14ThpCredentialRequest\x12,\n\x12host_static_pubkey\x18\x01\x20\x01(\
(\x0cR\x10hostStaticPubkey\x12\x20\n\x0bautoconnect\x18\x02\x20\x01(\x08\ \x0cR\x10hostStaticPubkey\x12\x20\n\x0bautoconnect\x18\x02\x20\x01(\x08R\
R\x0bautoconnect\"i\n\x15ThpCredentialResponse\x120\n\x14trezor_static_p\ \x0bautoconnect\"i\n\x15ThpCredentialResponse\x120\n\x14trezor_static_pu\
ubkey\x18\x01\x20\x01(\x0cR\x12trezorStaticPubkey\x12\x1e\n\ncredential\ bkey\x18\x01\x20\x01(\x0cR\x12trezorStaticPubkey\x12\x1e\n\ncredential\
\x18\x02\x20\x01(\x0cR\ncredential\"\x0f\n\rThpEndRequest\"\x10\n\x0eThp\ \x18\x02\x20\x01(\x0cR\ncredential\"\x0f\n\rThpEndRequest\"\x10\n\x0eThp\
EndResponse\"\\\n\x15ThpCredentialMetadata\x12\x1b\n\thost_name\x18\x01\ EndResponse\"\\\n\x15ThpCredentialMetadata\x12\x1b\n\thost_name\x18\x01\
\x20\x01(\tR\x08hostName\x12\x20\n\x0bautoconnect\x18\x02\x20\x01(\x08R\ \x20\x01(\tR\x08hostName\x12\x20\n\x0bautoconnect\x18\x02\x20\x01(\x08R\

View File

@ -5,6 +5,7 @@ import typing as t
import pytest import pytest
import typing_extensions as tx import typing_extensions as tx
from trezorlib import protobuf
from trezorlib.client import ProtocolV2 from trezorlib.client import ProtocolV2
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import ( from trezorlib.messages import (
@ -12,7 +13,9 @@ from trezorlib.messages import (
ButtonRequest, ButtonRequest,
ThpCodeEntryChallenge, ThpCodeEntryChallenge,
ThpCodeEntryCommitment, ThpCodeEntryCommitment,
ThpCodeEntryCpaceHostTag,
ThpCodeEntryCpaceTrezor, ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpEndRequest, ThpEndRequest,
ThpEndResponse, ThpEndResponse,
ThpPairingMethod, ThpPairingMethod,
@ -29,11 +32,26 @@ from trezorlib.transport.thp.protocol_v2 import MANAGEMENT_SESSION_ID, _hkdf
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
P = tx.ParamSpec("P") P = tx.ParamSpec("P")
MT = t.TypeVar("MT", bound=protobuf.MessageType)
pytestmark = [pytest.mark.protocol("protocol_v2")] pytestmark = [pytest.mark.protocol("protocol_v2")]
protocol: ProtocolV2
def _prepare_protocol(client: Client):
global protocol
protocol = client.protocol
protocol.sync_bit_send = 0
protocol.sync_bit_receive = 0
def test_allocate_channel(client: Client) -> None: def test_allocate_channel(client: Client) -> None:
protocol: ProtocolV2 = client.protocol global protocol
_prepare_protocol(client)
# protocol: ProtocolV2 = client.protocol
nonce = b"\x1A\x2B\x3B\x4A\x5C\x6D\x7E\x8F" nonce = b"\x1A\x2B\x3B\x4A\x5C\x6D\x7E\x8F"
# Use valid nonce # Use valid nonce
@ -50,10 +68,10 @@ def test_allocate_channel(client: Client) -> None:
def test_handshake(client: Client) -> None: def test_handshake(client: Client) -> None:
protocol: ProtocolV2 = client.protocol global protocol
_prepare_protocol(client)
# protocol: ProtocolV2 = client.protocol
protocol.sync_bit_send = 0
protocol.sync_bit_receive = 0
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
@ -88,10 +106,27 @@ def test_handshake(client: Client) -> None:
assert noise_tag is not None assert noise_tag is not None
def _send_message(
message: MT,
session_id: int = MANAGEMENT_SESSION_ID,
):
global protocol
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(session_id, message_type, message_data)
protocol._read_ack()
def _read_message(message_type: type[MT]) -> MT:
global protocol
_, msg_type, msg_data = protocol.read_and_decrypt()
msg = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(msg, message_type)
return msg
def test_pairing_qr_code(client: Client) -> None: def test_pairing_qr_code(client: Client) -> None:
protocol: ProtocolV2 = client.protocol global protocol
protocol.sync_bit_send = 0 _prepare_protocol(client)
protocol.sync_bit_receive = 0
# Generate ephemeral keys # Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
@ -101,96 +136,54 @@ def test_pairing_qr_code(client: Client) -> None:
protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey) protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
# Send StartPairingReqest message _send_message(ThpPairingRequest())
message = ThpPairingRequest()
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) _read_message(ButtonRequest)
# Read ACK _send_message(ButtonAck())
protocol._read_ack()
# Read button request
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ButtonRequest)
# Send button ACK
message = ButtonAck()
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack()
client.debug.press_yes() client.debug.press_yes()
# Read PairingRequestApproved _read_message(ThpPairingRequestApproved)
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpPairingRequestApproved) _send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode))
message = ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode) _read_message(ThpPairingPreparationsFinished)
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
protocol._read_ack()
# Read ThpPairingPreparationsFinished
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpPairingPreparationsFinished)
# QR Code shown # QR Code shown
# Read button request _read_message(ButtonRequest)
_, msg_type, msg_data = protocol.read_and_decrypt() _send_message(ButtonAck())
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ButtonRequest)
# Send button ACK
message = ButtonAck()
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack()
# Read code from "Trezor's display" using debuglink
state = client.debug.state(thp_channel_id=protocol.channel_id.to_bytes(2, "big")) state = client.debug.state(thp_channel_id=protocol.channel_id.to_bytes(2, "big"))
code = state.thp_pairing_code_qr_code
# Compute tag for response
sha_ctx = hashlib.sha256(protocol.handshake_hash) sha_ctx = hashlib.sha256(protocol.handshake_hash)
sha_ctx.update(state.thp_pairing_code_qr_code) sha_ctx.update(state.thp_pairing_code_qr_code)
tag = sha_ctx.digest() tag = sha_ctx.digest()
message_type, message_data = protocol.mapping.encode(ThpQrCodeTag(tag=tag)) _send_message(ThpQrCodeTag(tag=tag))
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack() secret_msg = _read_message(ThpQrCodeSecret)
# Read ThpQrCodeSecret # Check that the `code` was derived from the revealed secret
_, msg_type, msg_data = protocol.read_and_decrypt() sha_ctx = hashlib.sha256(ThpPairingMethod.QrCode.to_bytes(1, "big"))
maaa = protocol.mapping.decode(msg_type, msg_data) sha_ctx.update(protocol.handshake_hash)
assert isinstance(maaa, ThpQrCodeSecret) sha_ctx.update(secret_msg.secret)
computed_code = sha_ctx.digest()[:16]
assert code == computed_code
message = ThpEndRequest() _send_message(ThpEndRequest())
message_type, message_data = protocol.mapping.encode(message) _read_message(ThpEndResponse)
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
protocol._read_ack()
# Read ThpEndResponse
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpEndResponse)
protocol._has_valid_channel = True protocol._has_valid_channel = True
@pytest.mark.skip("Cpace is not implemented yet") @pytest.mark.skip("Cpace is not implemented yet")
def test_pairing_code_entry(client: Client) -> None: def test_pairing_code_entry(client: Client) -> None:
protocol: ProtocolV2 = client.protocol global protocol
protocol.sync_bit_send = 0 _prepare_protocol(client)
protocol.sync_bit_receive = 0
# Generate ephemeral keys # Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
@ -200,101 +193,57 @@ def test_pairing_code_entry(client: Client) -> None:
protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey) protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
# Send StartPairingReqest message _send_message(ThpPairingRequest())
message = ThpPairingRequest()
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) _read_message(ButtonRequest)
# Read ACK _send_message(ButtonAck())
protocol._read_ack()
# Read button request
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ButtonRequest)
# Send button ACK
message = ButtonAck()
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack()
client.debug.press_yes() client.debug.press_yes()
# Read PairingRequestApproved _read_message(ThpPairingRequestApproved)
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpPairingRequestApproved) _send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry))
message = ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry) commitment_msg = _read_message(ThpCodeEntryCommitment)
message_type, message_data = protocol.mapping.encode(message) commitment = commitment_msg.commitment
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
protocol._read_ack()
# Read ThpCodeEntryCommitment
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpCodeEntryCommitment)
challenge = b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xAA\xBB\xCC\xDD\xEE\xFF" challenge = b"\x00\x11\x22\x33\x44\x55\x66\x77\x88\x99\xAA\xBB\xCC\xDD\xEE\xFF"
message = ThpCodeEntryChallenge(challenge=challenge) _send_message(ThpCodeEntryChallenge(challenge=challenge))
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) cpace_trezor = _read_message(ThpCodeEntryCpaceTrezor)
# Read ACK cpace_trezor_public_key = cpace_trezor.cpace_trezor_public_key
protocol._read_ack()
# Read ThpCodeEntryCpaceTrezor
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpCodeEntryCpaceTrezor)
_ = maaa.cpace_trezor_public_key
# Code Entry code shown # Code Entry code shown
# Read button request _read_message(ButtonRequest)
_, msg_type, msg_data = protocol.read_and_decrypt() _send_message(ButtonAck())
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ButtonRequest)
# Send button ACK
message = ButtonAck()
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack()
state = client.debug.state(thp_channel_id=protocol.channel_id.to_bytes(2, "big")) state = client.debug.state(thp_channel_id=protocol.channel_id.to_bytes(2, "big"))
code = state.thp_pairing_code_entry_code
sha_ctx = hashlib.sha256(protocol.handshake_hash) # TODO fix missing CPACE
sha_ctx.update(state.thp_pairing_code_entry_code) cpace_shared_secret = b"\x01"
sha_ctx = hashlib.sha256(cpace_shared_secret)
tag = sha_ctx.digest() tag = sha_ctx.digest()
message_type, message_data = protocol.mapping.encode(ThpQrCodeTag(tag=tag)) cpace_host_public_key = cpace_trezor_public_key
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
protocol._read_ack() _send_message(
ThpCodeEntryCpaceHostTag(
cpace_host_public_key=cpace_host_public_key,
tag=tag,
)
)
# Read ThpQrCodeSecret secret_msg = _read_message(ThpCodeEntrySecret)
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpQrCodeSecret)
message = ThpEndRequest() # Check `commitment` and `code`
message_type, message_data = protocol.mapping.encode(message) sha_ctx = hashlib.sha256(secret_msg.secret)
computed_commitment = sha_ctx.digest()
assert commitment == computed_commitment
assert code == b"" # TODO implement
protocol._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) _send_message(ThpEndRequest())
# Read ACK _read_message(ThpEndResponse)
protocol._read_ack()
# Read ThpEndResponse
_, msg_type, msg_data = protocol.read_and_decrypt()
maaa = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, ThpEndResponse)
protocol._has_valid_channel = True protocol._has_valid_channel = True