1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-27 07:40:59 +00:00

refactor(python): improve structure of protocol_v2 channel establishment

[no changelog]
This commit is contained in:
M1nd3r 2025-01-25 22:47:30 +01:00
parent 6b3fa22c85
commit e39de541bd

View File

@ -132,40 +132,78 @@ class ProtocolV2(ProtocolAndChannel):
def _establish_new_channel(self) -> None:
self.sync_bit_send = 0
self.sync_bit_receive = 0
# Send channel allocation request
channel_id_request_nonce = os.urandom(8)
# Generate ephemeral keys
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
self._do_channel_allocation()
self._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
self._do_pairing()
def _do_channel_allocation(self) -> None:
channel_allocation_nonce = os.urandom(8)
self._send_channel_allocation_request(channel_allocation_nonce)
cid, dp = self._read_channel_allocation_response(channel_allocation_nonce)
self.channel_id = cid
self.device_properties = dp
def _send_channel_allocation_request(self, nonce: bytes):
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
MessageHeader.get_channel_allocation_request_header(12),
channel_id_request_nonce,
nonce,
)
# Read channel allocation response
def _read_channel_allocation_response(
self, expected_nonce: bytes
) -> tuple[int, bytes]:
header, payload = self._read_until_valid_crc_check()
if not self._is_valid_channel_allocation_response(
header, payload, channel_id_request_nonce
header, payload, expected_nonce
):
# TODO raise exception here, I guess
raise Exception("Invalid channel allocation response.")
self.channel_id = int.from_bytes(payload[8:10], "big")
self.device_properties = payload[10:]
channel_id = int.from_bytes(payload[8:10], "big")
device_properties = payload[10:]
return (channel_id, device_properties)
# Send handshake init request
def _do_handshake(
self, host_ephemeral_privkey: bytes, host_ephemeral_pubkey: bytes
):
self._send_handshake_init_request(host_ephemeral_pubkey)
self._read_ack()
init_response = self._read_handshake_init_response()
# assert len(init_response) == 96
trezor_ephemeral_pubkey = init_response[:32]
encrypted_trezor_static_pubkey = init_response[32:80]
noise_tag = init_response[80:96]
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# TODO check noise_tag is valid
ck = self._send_handshake_completion_request(
host_ephemeral_pubkey,
host_ephemeral_privkey,
trezor_ephemeral_pubkey,
encrypted_trezor_static_pubkey,
)
self._read_ack()
self._read_handshake_completion_response()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request = 0
self.nonce_response = 1
def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None:
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, ha_init_req_header, host_ephemeral_pubkey
)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
# Read handshake init response
def _read_handshake_init_response(self) -> bytes:
header, payload = self._read_until_valid_crc_check()
self._send_ack_0()
@ -180,15 +218,16 @@ class ProtocolV2(ProtocolAndChannel):
"Received message is not a valid handshake init response message",
err=True,
)
return payload
trezor_ephemeral_pubkey = payload[:32]
encrypted_trezor_static_pubkey = payload[32:80]
noise_tag = payload[80:96]
# TODO check noise tag
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# Prepare and send handshake completion request
def _send_handshake_completion_request(
self,
host_ephemeral_pubkey: bytes,
host_ephemeral_privkey: bytes,
trezor_ephemeral_pubkey: bytes,
encrypted_trezor_static_pubkey: bytes,
credential: bytes | None = None,
) -> bytes:
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
@ -232,7 +271,8 @@ class ProtocolV2(ProtocolAndChannel):
messages.ThpHandshakeCompletionReqNoisePayload(
pairing_methods=[
messages.ThpPairingMethod.NoMethod,
]
],
host_pairing_credential=credential,
)
)
@ -252,12 +292,9 @@ class ProtocolV2(ProtocolAndChannel):
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
)
return ck
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
def _read_handshake_completion_response(self) -> None:
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
@ -267,10 +304,7 @@ class ProtocolV2(ProtocolAndChannel):
)
self._send_ack_1()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request = 0
self.nonce_response = 1
def _do_pairing(self):
# Send StartPairingReqest message
message = messages.ThpStartPairingRequest()
message_type, message_data = self.mapping.encode(message)
@ -278,17 +312,20 @@ class ProtocolV2(ProtocolAndChannel):
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
# Read ACK
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
self._read_ack()
# Read
# Read ThpEndResponse
_, msg_type, msg_data = self.read_and_decrypt()
maaa = self.mapping.decode(msg_type, msg_data)
assert isinstance(maaa, messages.ThpEndResponse)
self._has_valid_channel = True
def _read_ack(self):
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
def _send_ack_0(self):
LOG.debug("sending ack 0")
header = MessageHeader(0x20, self.channel_id, 4)