1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-31 18:00:58 +00:00

refactor(python): improve protocolV2 and related tests

[no changelog]
This commit is contained in:
M1nd3r 2025-01-30 19:41:44 +01:00
parent 9b16b9b7a7
commit f8f2bfa535
2 changed files with 98 additions and 163 deletions

View File

@ -11,7 +11,7 @@ from enum import IntEnum
import click import click
from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import exceptions, messages from ... import exceptions, messages, protobuf
from ...mapping import ProtobufMapping from ...mapping import ProtobufMapping
from .. import Transport from .. import Transport
from ..thp import checksum, curve25519, thp_io from ..thp import checksum, curve25519, thp_io
@ -28,6 +28,7 @@ MANAGEMENT_SESSION_ID: int = 0
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from ...debuglink import DebugLink from ...debuglink import DebugLink
MT = t.TypeVar("MT", bound=protobuf.MessageType)
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
@ -135,20 +136,31 @@ class ProtocolV2(ProtocolAndChannel):
raise exceptions.TrezorException("Unexpected response to GetFeatures") raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = features self._features = features
def _send_message(
self,
message: protobuf.MessageType,
session_id: int = MANAGEMENT_SESSION_ID,
):
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(session_id, message_type, message_data)
self._read_ack()
def _read_message(self, message_type: type[MT]) -> MT:
_, msg_type, msg_data = self.read_and_decrypt()
msg = self.mapping.decode(msg_type, msg_data)
assert isinstance(msg, message_type)
return msg
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None: def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
self._reset_sync_bits()
self._do_channel_allocation()
self._do_handshake()
self._do_pairing(helper_debug)
def _reset_sync_bits(self) -> None:
self.sync_bit_send = 0 self.sync_bit_send = 0
self.sync_bit_receive = 0 self.sync_bit_receive = 0
# Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
self._do_channel_allocation()
self._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
self._do_pairing(helper_debug)
def _do_channel_allocation(self) -> None: def _do_channel_allocation(self) -> None:
channel_allocation_nonce = os.urandom(8) channel_allocation_nonce = os.urandom(8)
self._send_channel_allocation_request(channel_allocation_nonce) self._send_channel_allocation_request(channel_allocation_nonce)
@ -176,9 +188,10 @@ class ProtocolV2(ProtocolAndChannel):
device_properties = payload[10:] device_properties = payload[10:]
return (channel_id, device_properties) return (channel_id, device_properties)
def _do_handshake( def _do_handshake(self):
self, host_ephemeral_privkey: bytes, host_ephemeral_pubkey: bytes host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
): host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
self._send_handshake_init_request(host_ephemeral_pubkey) self._send_handshake_init_request(host_ephemeral_pubkey)
self._read_ack() self._read_ack()
init_response = self._read_handshake_init_response() init_response = self._read_handshake_init_response()
@ -309,49 +322,21 @@ class ProtocolV2(ProtocolAndChannel):
self._send_ack_1() self._send_ack_1()
def _do_pairing(self, helper_debug: DebugLink | None): def _do_pairing(self, helper_debug: DebugLink | None):
# Send StartPairingReqest message
message = messages.ThpPairingRequest()
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) self._send_message(messages.ThpPairingRequest())
self._read_message(messages.ButtonRequest)
# Read ACK self._send_message(messages.ButtonAck())
self._read_ack()
# Read button request
_, msg_type, msg_data = self.read_and_decrypt()
maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ButtonRequest)
# Send button ACK
message = messages.ButtonAck()
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
self._read_ack()
if helper_debug is not None: if helper_debug is not None:
helper_debug.press_yes() helper_debug.press_yes()
# Read PairingRequestApproved self._read_message(messages.ThpPairingRequestApproved)
_, msg_type, msg_data = self.read_and_decrypt() self._send_message(
maaa = self.mapping.decode(msg_type, msg_data) messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
assert isinstance(maaa, messages.ThpPairingRequestApproved) )
message = messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
) )
message_type, message_data = self.mapping.encode(message) self._read_message(messages.ThpEndResponse)
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
self._read_ack()
# Read ThpEndResponse
_, msg_type, msg_data = self.read_and_decrypt()
maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ThpEndResponse)
self._has_valid_channel = True self._has_valid_channel = True

View File

@ -31,7 +31,7 @@ from trezorlib.messages import (
) )
from trezorlib.transport.thp import curve25519 from trezorlib.transport.thp import curve25519
from trezorlib.transport.thp.cpace import Cpace from trezorlib.transport.thp.cpace import Cpace
from trezorlib.transport.thp.protocol_v2 import MANAGEMENT_SESSION_ID, _hkdf from trezorlib.transport.thp.protocol_v2 import _hkdf
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
P = tx.ParamSpec("P") P = tx.ParamSpec("P")
@ -41,22 +41,24 @@ MT = t.TypeVar("MT", bound=protobuf.MessageType)
pytestmark = [pytest.mark.protocol("protocol_v2")] pytestmark = [pytest.mark.protocol("protocol_v2")]
protocol: ProtocolV2 def _prepare_protocol(client: Client) -> ProtocolV2:
def _prepare_protocol(client: Client):
global protocol
protocol = client.protocol protocol = client.protocol
protocol.sync_bit_send = 0 assert isinstance(protocol, ProtocolV2)
protocol.sync_bit_receive = 0 protocol._reset_sync_bits()
return protocol
def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2:
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake()
return protocol
def test_allocate_channel(client: Client) -> None: def test_allocate_channel(client: Client) -> None:
global protocol protocol = _prepare_protocol(client)
_prepare_protocol(client)
# protocol: ProtocolV2 = client.protocol nonce = random.randbytes(8)
nonce = b"\x1A\x2B\x3B\x4A\x5C\x6D\x7E\x8F"
# Use valid nonce # Use valid nonce
protocol._send_channel_allocation_request(nonce) protocol._send_channel_allocation_request(nonce)
@ -72,9 +74,7 @@ def test_allocate_channel(client: Client) -> None:
def test_handshake(client: Client) -> None: def test_handshake(client: Client) -> None:
global protocol protocol = _prepare_protocol(client)
_prepare_protocol(client)
# protocol: ProtocolV2 = client.protocol
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
@ -110,53 +110,24 @@ def test_handshake(client: Client) -> None:
assert noise_tag is not None assert noise_tag is not None
def _send_message(
message: MT,
session_id: int = MANAGEMENT_SESSION_ID,
):
global protocol
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(session_id, message_type, message_data)
protocol._read_ack()
def _read_message(message_type: type[MT]) -> MT:
global protocol
_, msg_type, msg_data = protocol.read_and_decrypt()
msg = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(msg, message_type)
return msg
def test_pairing_qr_code(client: Client) -> None: def test_pairing_qr_code(client: Client) -> None:
global protocol protocol = _prepare_protocol_for_pairing(client)
_prepare_protocol(client)
# Generate ephemeral keys protocol._send_message(ThpPairingRequest())
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) protocol._read_message(ButtonRequest)
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) protocol._send_message(ButtonAck())
protocol._do_channel_allocation()
protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
_send_message(ThpPairingRequest())
_read_message(ButtonRequest)
_send_message(ButtonAck())
client.debug.press_yes() client.debug.press_yes()
_read_message(ThpPairingRequestApproved) protocol._read_message(ThpPairingRequestApproved)
protocol._send_message(
_send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)) ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)
)
_read_message(ThpPairingPreparationsFinished) protocol._read_message(ThpPairingPreparationsFinished)
# QR Code shown # QR Code shown
_read_message(ButtonRequest) protocol._read_message(ButtonRequest)
_send_message(ButtonAck()) protocol._send_message(ButtonAck())
# Read code from "Trezor's display" using debuglink # Read code from "Trezor's display" using debuglink
@ -170,9 +141,9 @@ def test_pairing_qr_code(client: Client) -> None:
sha_ctx.update(code) sha_ctx.update(code)
tag = sha_ctx.digest() tag = sha_ctx.digest()
_send_message(ThpQrCodeTag(tag=tag)) protocol._send_message(ThpQrCodeTag(tag=tag))
secret_msg = _read_message(ThpQrCodeSecret) secret_msg = protocol._read_message(ThpQrCodeSecret)
# Check that the `code` was derived from the revealed secret # Check that the `code` was derived from the revealed secret
sha_ctx = sha256(ThpPairingMethod.QrCode.to_bytes(1, "big")) sha_ctx = sha256(ThpPairingMethod.QrCode.to_bytes(1, "big"))
@ -181,48 +152,38 @@ def test_pairing_qr_code(client: Client) -> None:
computed_code = sha_ctx.digest()[:16] computed_code = sha_ctx.digest()[:16]
assert code == computed_code assert code == computed_code
_send_message(ThpEndRequest()) protocol._send_message(ThpEndRequest())
_read_message(ThpEndResponse) protocol._read_message(ThpEndResponse)
protocol._has_valid_channel = True protocol._has_valid_channel = True
def test_pairing_code_entry(client: Client) -> None: def test_pairing_code_entry(client: Client) -> None:
global protocol protocol = _prepare_protocol_for_pairing(client)
_prepare_protocol(client)
# Generate ephemeral keys protocol._send_message(ThpPairingRequest())
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) protocol._read_message(ButtonRequest)
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) protocol._send_message(ButtonAck())
protocol._do_channel_allocation()
protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
_send_message(ThpPairingRequest())
_read_message(ButtonRequest)
_send_message(ButtonAck())
client.debug.press_yes() client.debug.press_yes()
_read_message(ThpPairingRequestApproved) protocol._read_message(ThpPairingRequestApproved)
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)
)
_send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)) commitment_msg = protocol._read_message(ThpCodeEntryCommitment)
commitment_msg = _read_message(ThpCodeEntryCommitment)
commitment = commitment_msg.commitment commitment = commitment_msg.commitment
challenge = random.randbytes(16) challenge = random.randbytes(16)
_send_message(ThpCodeEntryChallenge(challenge=challenge)) protocol._send_message(ThpCodeEntryChallenge(challenge=challenge))
cpace_trezor = _read_message(ThpCodeEntryCpaceTrezor) cpace_trezor = protocol._read_message(ThpCodeEntryCpaceTrezor)
cpace_trezor_public_key = cpace_trezor.cpace_trezor_public_key cpace_trezor_public_key = cpace_trezor.cpace_trezor_public_key
# Code Entry code shown # Code Entry code shown
_read_message(ButtonRequest) protocol._read_message(ButtonRequest)
_send_message(ButtonAck()) protocol._send_message(ButtonAck())
pairing_info = client.debug.pairing_info( pairing_info = client.debug.pairing_info(
thp_channel_id=protocol.channel_id.to_bytes(2, "big") thp_channel_id=protocol.channel_id.to_bytes(2, "big")
@ -235,14 +196,14 @@ def test_pairing_code_entry(client: Client) -> None:
sha_ctx = sha256(cpace.shared_secret) sha_ctx = sha256(cpace.shared_secret)
tag = sha_ctx.digest() tag = sha_ctx.digest()
_send_message( protocol._send_message(
ThpCodeEntryCpaceHostTag( ThpCodeEntryCpaceHostTag(
cpace_host_public_key=cpace.host_public_key, cpace_host_public_key=cpace.host_public_key,
tag=tag, tag=tag,
) )
) )
secret_msg = _read_message(ThpCodeEntrySecret) secret_msg = protocol._read_message(ThpCodeEntrySecret)
# Check `commitment` and `code` # Check `commitment` and `code`
sha_ctx = sha256(secret_msg.secret) sha_ctx = sha256(secret_msg.secret)
@ -257,41 +218,30 @@ def test_pairing_code_entry(client: Client) -> None:
computed_code = int.from_bytes(code_hash, "big") % 1000000 computed_code = int.from_bytes(code_hash, "big") % 1000000
assert code == computed_code assert code == computed_code
_send_message(ThpEndRequest()) protocol._send_message(ThpEndRequest())
_read_message(ThpEndResponse) protocol._read_message(ThpEndResponse)
protocol._has_valid_channel = True protocol._has_valid_channel = True
def test_pairing_nfc(client: Client) -> None: def test_pairing_nfc(client: Client) -> None:
global protocol protocol = _prepare_protocol_for_pairing(client)
_prepare_protocol(client)
# Generate ephemeral keys protocol._send_message(ThpPairingRequest())
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) protocol._read_message(ButtonRequest)
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) protocol._send_message(ButtonAck())
protocol._do_channel_allocation()
protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
_send_message(ThpPairingRequest())
_read_message(ButtonRequest)
_send_message(ButtonAck())
client.debug.press_yes() client.debug.press_yes()
_read_message(ThpPairingRequestApproved) protocol._read_message(ThpPairingRequestApproved)
protocol._send_message(
_send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC)) ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC)
)
_read_message(ThpPairingPreparationsFinished) protocol._read_message(ThpPairingPreparationsFinished)
# NFC screen shown # NFC screen shown
_read_message(ButtonRequest) protocol._read_message(ButtonRequest)
_send_message(ButtonAck()) protocol._send_message(ButtonAck())
nfc_secret_host = random.randbytes(16) nfc_secret_host = random.randbytes(16)
# Read `nfc_secret` and `handshake_hash` from Trezor using debuglink # Read `nfc_secret` and `handshake_hash` from Trezor using debuglink
@ -311,9 +261,9 @@ def test_pairing_nfc(client: Client) -> None:
sha_ctx.update(nfc_secret_trezor) sha_ctx.update(nfc_secret_trezor)
tag_host = sha_ctx.digest() tag_host = sha_ctx.digest()
_send_message(ThpNfcTagHost(tag=tag_host)) protocol._send_message(ThpNfcTagHost(tag=tag_host))
tag_trezor_msg = _read_message(ThpNfcTagTrezor) tag_trezor_msg = protocol._read_message(ThpNfcTagTrezor)
# Check that the `code` was derived from the revealed secret # Check that the `code` was derived from the revealed secret
sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big")) sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big"))
@ -322,7 +272,7 @@ def test_pairing_nfc(client: Client) -> None:
computed_tag = sha_ctx.digest() computed_tag = sha_ctx.digest()
assert tag_trezor_msg.tag == computed_tag assert tag_trezor_msg.tag == computed_tag
_send_message(ThpEndRequest()) protocol._send_message(ThpEndRequest())
_read_message(ThpEndResponse) protocol._read_message(ThpEndResponse)
protocol._has_valid_channel = True protocol._has_valid_channel = True