1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-31 09:50:58 +00:00

refactor(python): improve protocolV2 and related tests

[no changelog]
This commit is contained in:
M1nd3r 2025-01-30 19:41:44 +01:00
parent 9b16b9b7a7
commit f8f2bfa535
2 changed files with 98 additions and 163 deletions

View File

@ -11,7 +11,7 @@ from enum import IntEnum
import click
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import exceptions, messages
from ... import exceptions, messages, protobuf
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp import checksum, curve25519, thp_io
@ -28,6 +28,7 @@ MANAGEMENT_SESSION_ID: int = 0
if t.TYPE_CHECKING:
from ...debuglink import DebugLink
MT = t.TypeVar("MT", bound=protobuf.MessageType)
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
@ -135,20 +136,31 @@ class ProtocolV2(ProtocolAndChannel):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = features
def _send_message(
self,
message: protobuf.MessageType,
session_id: int = MANAGEMENT_SESSION_ID,
):
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(session_id, message_type, message_data)
self._read_ack()
def _read_message(self, message_type: type[MT]) -> MT:
_, msg_type, msg_data = self.read_and_decrypt()
msg = self.mapping.decode(msg_type, msg_data)
assert isinstance(msg, message_type)
return msg
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
self._reset_sync_bits()
self._do_channel_allocation()
self._do_handshake()
self._do_pairing(helper_debug)
def _reset_sync_bits(self) -> None:
self.sync_bit_send = 0
self.sync_bit_receive = 0
# Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
self._do_channel_allocation()
self._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
self._do_pairing(helper_debug)
def _do_channel_allocation(self) -> None:
channel_allocation_nonce = os.urandom(8)
self._send_channel_allocation_request(channel_allocation_nonce)
@ -176,9 +188,10 @@ class ProtocolV2(ProtocolAndChannel):
device_properties = payload[10:]
return (channel_id, device_properties)
def _do_handshake(
self, host_ephemeral_privkey: bytes, host_ephemeral_pubkey: bytes
):
def _do_handshake(self):
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
self._send_handshake_init_request(host_ephemeral_pubkey)
self._read_ack()
init_response = self._read_handshake_init_response()
@ -309,49 +322,21 @@ class ProtocolV2(ProtocolAndChannel):
self._send_ack_1()
def _do_pairing(self, helper_debug: DebugLink | None):
# Send StartPairingReqest message
message = messages.ThpPairingRequest()
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
self._read_ack()
# Read button request
_, msg_type, msg_data = self.read_and_decrypt()
maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ButtonRequest)
# Send button ACK
message = messages.ButtonAck()
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
self._read_ack()
self._send_message(messages.ThpPairingRequest())
self._read_message(messages.ButtonRequest)
self._send_message(messages.ButtonAck())
if helper_debug is not None:
helper_debug.press_yes()
# Read PairingRequestApproved
_, msg_type, msg_data = self.read_and_decrypt()
maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ThpPairingRequestApproved)
message = messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
self._read_message(messages.ThpPairingRequestApproved)
self._send_message(
messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
)
)
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
self._read_ack()
# Read ThpEndResponse
_, msg_type, msg_data = self.read_and_decrypt()
maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ThpEndResponse)
self._read_message(messages.ThpEndResponse)
self._has_valid_channel = True

View File

@ -31,7 +31,7 @@ from trezorlib.messages import (
)
from trezorlib.transport.thp import curve25519
from trezorlib.transport.thp.cpace import Cpace
from trezorlib.transport.thp.protocol_v2 import MANAGEMENT_SESSION_ID, _hkdf
from trezorlib.transport.thp.protocol_v2 import _hkdf
if t.TYPE_CHECKING:
P = tx.ParamSpec("P")
@ -41,22 +41,24 @@ MT = t.TypeVar("MT", bound=protobuf.MessageType)
pytestmark = [pytest.mark.protocol("protocol_v2")]
protocol: ProtocolV2
def _prepare_protocol(client: Client):
global protocol
def _prepare_protocol(client: Client) -> ProtocolV2:
protocol = client.protocol
protocol.sync_bit_send = 0
protocol.sync_bit_receive = 0
assert isinstance(protocol, ProtocolV2)
protocol._reset_sync_bits()
return protocol
def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2:
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake()
return protocol
def test_allocate_channel(client: Client) -> None:
global protocol
_prepare_protocol(client)
protocol = _prepare_protocol(client)
# protocol: ProtocolV2 = client.protocol
nonce = b"\x1A\x2B\x3B\x4A\x5C\x6D\x7E\x8F"
nonce = random.randbytes(8)
# Use valid nonce
protocol._send_channel_allocation_request(nonce)
@ -72,9 +74,7 @@ def test_allocate_channel(client: Client) -> None:
def test_handshake(client: Client) -> None:
global protocol
_prepare_protocol(client)
# protocol: ProtocolV2 = client.protocol
protocol = _prepare_protocol(client)
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
@ -110,53 +110,24 @@ def test_handshake(client: Client) -> None:
assert noise_tag is not None
def _send_message(
message: MT,
session_id: int = MANAGEMENT_SESSION_ID,
):
global protocol
message_type, message_data = protocol.mapping.encode(message)
protocol._encrypt_and_write(session_id, message_type, message_data)
protocol._read_ack()
def _read_message(message_type: type[MT]) -> MT:
global protocol
_, msg_type, msg_data = protocol.read_and_decrypt()
msg = protocol.mapping.decode(msg_type, msg_data)
assert isinstance(msg, message_type)
return msg
def test_pairing_qr_code(client: Client) -> None:
global protocol
_prepare_protocol(client)
protocol = _prepare_protocol_for_pairing(client)
# Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
protocol._do_channel_allocation()
protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
_send_message(ThpPairingRequest())
_read_message(ButtonRequest)
_send_message(ButtonAck())
protocol._send_message(ThpPairingRequest())
protocol._read_message(ButtonRequest)
protocol._send_message(ButtonAck())
client.debug.press_yes()
_read_message(ThpPairingRequestApproved)
_send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode))
_read_message(ThpPairingPreparationsFinished)
protocol._read_message(ThpPairingRequestApproved)
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)
)
protocol._read_message(ThpPairingPreparationsFinished)
# QR Code shown
_read_message(ButtonRequest)
_send_message(ButtonAck())
protocol._read_message(ButtonRequest)
protocol._send_message(ButtonAck())
# Read code from "Trezor's display" using debuglink
@ -170,9 +141,9 @@ def test_pairing_qr_code(client: Client) -> None:
sha_ctx.update(code)
tag = sha_ctx.digest()
_send_message(ThpQrCodeTag(tag=tag))
protocol._send_message(ThpQrCodeTag(tag=tag))
secret_msg = _read_message(ThpQrCodeSecret)
secret_msg = protocol._read_message(ThpQrCodeSecret)
# Check that the `code` was derived from the revealed secret
sha_ctx = sha256(ThpPairingMethod.QrCode.to_bytes(1, "big"))
@ -181,48 +152,38 @@ def test_pairing_qr_code(client: Client) -> None:
computed_code = sha_ctx.digest()[:16]
assert code == computed_code
_send_message(ThpEndRequest())
_read_message(ThpEndResponse)
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
protocol._has_valid_channel = True
def test_pairing_code_entry(client: Client) -> None:
global protocol
_prepare_protocol(client)
protocol = _prepare_protocol_for_pairing(client)
# Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
protocol._do_channel_allocation()
protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
_send_message(ThpPairingRequest())
_read_message(ButtonRequest)
_send_message(ButtonAck())
protocol._send_message(ThpPairingRequest())
protocol._read_message(ButtonRequest)
protocol._send_message(ButtonAck())
client.debug.press_yes()
_read_message(ThpPairingRequestApproved)
protocol._read_message(ThpPairingRequestApproved)
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)
)
_send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry))
commitment_msg = _read_message(ThpCodeEntryCommitment)
commitment_msg = protocol._read_message(ThpCodeEntryCommitment)
commitment = commitment_msg.commitment
challenge = random.randbytes(16)
_send_message(ThpCodeEntryChallenge(challenge=challenge))
protocol._send_message(ThpCodeEntryChallenge(challenge=challenge))
cpace_trezor = _read_message(ThpCodeEntryCpaceTrezor)
cpace_trezor = protocol._read_message(ThpCodeEntryCpaceTrezor)
cpace_trezor_public_key = cpace_trezor.cpace_trezor_public_key
# Code Entry code shown
_read_message(ButtonRequest)
_send_message(ButtonAck())
protocol._read_message(ButtonRequest)
protocol._send_message(ButtonAck())
pairing_info = client.debug.pairing_info(
thp_channel_id=protocol.channel_id.to_bytes(2, "big")
@ -235,14 +196,14 @@ def test_pairing_code_entry(client: Client) -> None:
sha_ctx = sha256(cpace.shared_secret)
tag = sha_ctx.digest()
_send_message(
protocol._send_message(
ThpCodeEntryCpaceHostTag(
cpace_host_public_key=cpace.host_public_key,
tag=tag,
)
)
secret_msg = _read_message(ThpCodeEntrySecret)
secret_msg = protocol._read_message(ThpCodeEntrySecret)
# Check `commitment` and `code`
sha_ctx = sha256(secret_msg.secret)
@ -257,41 +218,30 @@ def test_pairing_code_entry(client: Client) -> None:
computed_code = int.from_bytes(code_hash, "big") % 1000000
assert code == computed_code
_send_message(ThpEndRequest())
_read_message(ThpEndResponse)
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
protocol._has_valid_channel = True
def test_pairing_nfc(client: Client) -> None:
global protocol
_prepare_protocol(client)
protocol = _prepare_protocol_for_pairing(client)
# Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
protocol._do_channel_allocation()
protocol._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
_send_message(ThpPairingRequest())
_read_message(ButtonRequest)
_send_message(ButtonAck())
protocol._send_message(ThpPairingRequest())
protocol._read_message(ButtonRequest)
protocol._send_message(ButtonAck())
client.debug.press_yes()
_read_message(ThpPairingRequestApproved)
_send_message(ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC))
_read_message(ThpPairingPreparationsFinished)
protocol._read_message(ThpPairingRequestApproved)
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC)
)
protocol._read_message(ThpPairingPreparationsFinished)
# NFC screen shown
_read_message(ButtonRequest)
_send_message(ButtonAck())
protocol._read_message(ButtonRequest)
protocol._send_message(ButtonAck())
nfc_secret_host = random.randbytes(16)
# Read `nfc_secret` and `handshake_hash` from Trezor using debuglink
@ -311,9 +261,9 @@ def test_pairing_nfc(client: Client) -> None:
sha_ctx.update(nfc_secret_trezor)
tag_host = sha_ctx.digest()
_send_message(ThpNfcTagHost(tag=tag_host))
protocol._send_message(ThpNfcTagHost(tag=tag_host))
tag_trezor_msg = _read_message(ThpNfcTagTrezor)
tag_trezor_msg = protocol._read_message(ThpNfcTagTrezor)
# Check that the `code` was derived from the revealed secret
sha_ctx = sha256(ThpPairingMethod.NFC.to_bytes(1, "big"))
@ -322,7 +272,7 @@ def test_pairing_nfc(client: Client) -> None:
computed_tag = sha_ctx.digest()
assert tag_trezor_msg.tag == computed_tag
_send_message(ThpEndRequest())
_read_message(ThpEndResponse)
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
protocol._has_valid_channel = True