1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-08 06:20:56 +00:00

chore(trezorlib): clean trezolib THP, part 1

[no changelog]
This commit is contained in:
M1nd3r 2024-10-14 17:52:58 +02:00
parent 386595cab2
commit 594ec26fa7
3 changed files with 67 additions and 42 deletions

View File

@ -8,6 +8,7 @@ import typing as t
from binascii import hexlify from binascii import hexlify
from enum import IntEnum from enum import IntEnum
import click
from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import exceptions, messages from ... import exceptions, messages
@ -15,7 +16,7 @@ from ...mapping import ProtobufMapping
from .. import Transport from .. import Transport
from ..thp import checksum, curve25519, thp_io from ..thp import checksum, curve25519, thp_io
from ..thp.checksum import CHECKSUM_LENGTH from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.packet_header import PacketHeader from ..thp.message_header import MessageHeader
from . import channel_database, control_byte from . import channel_database, control_byte
from .channel_data import ChannelData from .channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel from .protocol_and_channel import ProtocolAndChannel
@ -97,7 +98,7 @@ class ProtocolV2(ProtocolAndChannel):
def read(self, session_id: int) -> t.Any: def read(self, session_id: int) -> t.Any:
sid, msg_type, msg_data = self.read_and_decrypt() sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id: if sid != session_id:
raise Exception("Received messsage on different session.") raise Exception("Received messsage on a different session.")
channel_database.save_channel(self) channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data) return self.mapping.decode(msg_type, msg_data)
@ -133,7 +134,7 @@ class ProtocolV2(ProtocolAndChannel):
channel_id_request_nonce = os.urandom(8) channel_id_request_nonce = os.urandom(8)
thp_io.write_payload_to_wire_and_add_checksum( thp_io.write_payload_to_wire_and_add_checksum(
self.transport, self.transport,
PacketHeader.get_channel_allocation_request_header(12), MessageHeader.get_channel_allocation_request_header(12),
channel_id_request_nonce, channel_id_request_nonce,
) )
@ -142,13 +143,14 @@ class ProtocolV2(ProtocolAndChannel):
if not self._is_valid_channel_allocation_response( if not self._is_valid_channel_allocation_response(
header, payload, channel_id_request_nonce header, payload, channel_id_request_nonce
): ):
print("TODO raise exception here, I guess") # TODO raise exception here, I guess
raise Exception("Invalid channel allocation response.")
self.channel_id = int.from_bytes(payload[8:10], "big") self.channel_id = int.from_bytes(payload[8:10], "big")
self.device_properties = payload[10:] self.device_properties = payload[10:]
# Send handshake init request # Send handshake init request
ha_init_req_header = PacketHeader(0, self.channel_id, 36) ha_init_req_header = MessageHeader(0, self.channel_id, 36)
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
@ -159,14 +161,17 @@ class ProtocolV2(ProtocolAndChannel):
# Read ACK # Read ACK
header, payload = self._read_until_valid_crc_check() header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0: if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ") click.echo("Received message is not a valid ACK", err=True)
# Read handshake init response # Read handshake init response
header, payload = self._read_until_valid_crc_check() header, payload = self._read_until_valid_crc_check()
self._send_ack_0() self._send_ack_0()
if not header.is_handshake_init_response(): if not header.is_handshake_init_response():
print("Received message is not a valid handshake init response message") click.echo(
"Received message is not a valid handshake init response message",
err=True,
)
trezor_ephemeral_pubkey = payload[:32] trezor_ephemeral_pubkey = payload[:32]
encrypted_trezor_static_pubkey = payload[32:80] encrypted_trezor_static_pubkey = payload[32:80]
@ -193,7 +198,9 @@ class ProtocolV2(ProtocolAndChannel):
IV_1, encrypted_trezor_static_pubkey, h IV_1, encrypted_trezor_static_pubkey, h
) )
except Exception as e: except Exception as e:
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik click.echo(
f"Exception of type{type(e)}", err=True
) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey) h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf( ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
@ -225,7 +232,7 @@ class ProtocolV2(ProtocolAndChannel):
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload) h = _sha256_of_two(h, encrypted_payload)
ha_completion_req_header = PacketHeader( ha_completion_req_header = MessageHeader(
0x12, 0x12,
self.channel_id, self.channel_id,
len(encrypted_host_static_pubkey) len(encrypted_host_static_pubkey)
@ -241,12 +248,15 @@ class ProtocolV2(ProtocolAndChannel):
# Read ACK # Read ACK
header, payload = self._read_until_valid_crc_check() header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0: if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ") click.echo("Received message is not a valid ACK", err=True)
# Read handshake completion response, ignore payload as we do not care about the state # Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check() header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response(): if not header.is_handshake_comp_response():
print("Received message is not a valid handshake completion response") click.echo(
"Received message is not a valid handshake completion response",
err=True,
)
self._send_ack_1() self._send_ack_1()
self.key_request, self.key_response = _hkdf(ck, b"") self.key_request, self.key_response = _hkdf(ck, b"")
@ -262,7 +272,7 @@ class ProtocolV2(ProtocolAndChannel):
# Read ACK # Read ACK
header, payload = self._read_until_valid_crc_check() header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0: if not header.is_ack() or len(payload) > 0:
print("Received message is not a valid ACK ") click.echo("Received message is not a valid ACK", err=True)
# Read # Read
_, msg_type, msg_data = self.read_and_decrypt() _, msg_type, msg_data = self.read_and_decrypt()
@ -273,12 +283,12 @@ class ProtocolV2(ProtocolAndChannel):
def _send_ack_0(self): def _send_ack_0(self):
LOG.debug("sending ack 0") LOG.debug("sending ack 0")
header = PacketHeader(0x20, self.channel_id, 4) header = MessageHeader(0x20, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _send_ack_1(self): def _send_ack_1(self):
LOG.debug("sending ack 1") LOG.debug("sending ack 1")
header = PacketHeader(0x28, self.channel_id, 4) header = MessageHeader(0x28, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _encrypt_and_write( def _encrypt_and_write(
@ -301,7 +311,7 @@ class ProtocolV2(ProtocolAndChannel):
nonce = _get_iv_from_nonce(self.nonce_request) nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1 self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"") encrypted_message = aes_ctx.encrypt(nonce, data, b"")
header = PacketHeader( header = MessageHeader(
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
) )
@ -314,9 +324,10 @@ class ProtocolV2(ProtocolAndChannel):
if control_byte.is_ack(header.ctrl_byte): if control_byte.is_ack(header.ctrl_byte):
return self.read_and_decrypt() return self.read_and_decrypt()
if not header.is_encrypted_transport(): if not header.is_encrypted_transport():
print("Trying to decrypt not encrypted message!") click.echo(
print( "Trying to decrypt not encrypted message!"
hexlify(header.to_bytes_init()).decode(), hexlify(raw_payload).decode() + hexlify(header.to_bytes_init() + raw_payload).decode(),
err=True,
) )
if not control_byte.is_ack(header.ctrl_byte): if not control_byte.is_ack(header.ctrl_byte):
@ -346,29 +357,36 @@ class ProtocolV2(ProtocolAndChannel):
def _read_until_valid_crc_check( def _read_until_valid_crc_check(
self, self,
) -> t.Tuple[PacketHeader, bytes]: ) -> t.Tuple[MessageHeader, bytes]:
is_valid = False is_valid = False
header, payload, chksum = thp_io.read(self.transport) header, payload, chksum = thp_io.read(self.transport)
while not is_valid: while not is_valid:
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
if not is_valid: if not is_valid:
print(hexlify(header.to_bytes_init() + payload + chksum)) click.echo(
LOG.debug("Received a message with invalid checksum") "Received a message with an invalid checksum:"
+ hexlify(header.to_bytes_init() + payload + chksum),
err=True,
)
header, payload, chksum = thp_io.read(self.transport) header, payload, chksum = thp_io.read(self.transport)
return header, payload return header, payload
def _is_valid_channel_allocation_response( def _is_valid_channel_allocation_response(
self, header: PacketHeader, payload: bytes, original_nonce: bytes self, header: MessageHeader, payload: bytes, original_nonce: bytes
) -> bool: ) -> bool:
if not header.is_channel_allocation_response(): if not header.is_channel_allocation_response():
print("Received message is not a channel allocation response") click.echo(
"Received message is not a channel allocation response", err=True
)
return False return False
if len(payload) < 10: if len(payload) < 10:
print("Invalid channel allocation response payload") click.echo("Invalid channel allocation response payload", err=True)
return False return False
if payload[:8] != original_nonce: if payload[:8] != original_nonce:
print("Invalid channel allocation response payload (nonce mismatch)") click.echo(
"Invalid channel allocation response payload (nonce mismatch)", err=True
)
return False return False
return True return True

View File

@ -23,7 +23,7 @@ TREZOR_STATE_PAIRED = b"\x01"
BROADCAST_CHANNEL_ID = 0xFFFF BROADCAST_CHANNEL_ID = 0xFFFF
class PacketHeader: class MessageHeader:
format_str_init = ">BHH" format_str_init = ">BHH"
format_str_cont = ">BH" format_str_cont = ">BH"

View File

@ -3,7 +3,7 @@ from typing import Tuple
from .. import Transport from .. import Transport
from ..thp import checksum from ..thp import checksum
from .packet_header import PacketHeader from .message_header import MessageHeader
INIT_HEADER_LENGTH = 5 INIT_HEADER_LENGTH = 5
CONT_HEADER_LENGTH = 3 CONT_HEADER_LENGTH = 3
@ -14,7 +14,7 @@ CONTINUATION_PACKET = 0x80
def write_payload_to_wire_and_add_checksum( def write_payload_to_wire_and_add_checksum(
transport: Transport, header: PacketHeader, transport_payload: bytes transport: Transport, header: MessageHeader, transport_payload: bytes
): ):
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload) chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
data = transport_payload + chksum data = transport_payload + chksum
@ -22,7 +22,7 @@ def write_payload_to_wire_and_add_checksum(
def write_payload_to_wire( def write_payload_to_wire(
transport: Transport, header: PacketHeader, transport_payload: bytes transport: Transport, header: MessageHeader, transport_payload: bytes
): ):
transport.open() transport.open()
buffer = bytearray(transport_payload) buffer = bytearray(transport_payload)
@ -40,8 +40,18 @@ def write_payload_to_wire(
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :] buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
def read(transport: Transport) -> Tuple[PacketHeader, bytes, bytes]: def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
"""
Reads from the given wire transport.
Returns `Tuple[MessageHeader, bytes, bytes]`:
1. `header` (`MessageHeader`): Header of the message.
2. `data` (`bytes`): Contents of the message (if any).
3. `checksum` (`bytes`): crc32 checksum of the header + data.
"""
buffer = bytearray() buffer = bytearray()
# Read header with first part of message data # Read header with first part of message data
header, first_chunk = read_first(transport) header, first_chunk = read_first(transport)
buffer.extend(first_chunk) buffer.extend(first_chunk)
@ -49,34 +59,31 @@ def read(transport: Transport) -> Tuple[PacketHeader, bytes, bytes]:
# Read the rest of the message # Read the rest of the message
while len(buffer) < header.data_length: while len(buffer) < header.data_length:
buffer.extend(read_next(transport, header.cid)) buffer.extend(read_next(transport, header.cid))
# print("buffer read (data):", hexlify(buffer).decode())
# print("buffer len (data):", datalen)
# TODO check checksum?? or do not strip ?
data_len = header.data_length - checksum.CHECKSUM_LENGTH data_len = header.data_length - checksum.CHECKSUM_LENGTH
return ( msg_data = buffer[:data_len]
header, chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
buffer[:data_len],
buffer[data_len : data_len + checksum.CHECKSUM_LENGTH], return (header, msg_data, chksum)
)
def read_first(transport: Transport) -> Tuple[PacketHeader, bytes]: def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
chunk = transport.read_chunk() chunk = transport.read_chunk()
try: try:
ctrl_byte, cid, data_length = struct.unpack( ctrl_byte, cid, data_length = struct.unpack(
PacketHeader.format_str_init, chunk[:INIT_HEADER_LENGTH] MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
) )
except Exception: except Exception:
raise RuntimeError("Cannot parse header") raise RuntimeError("Cannot parse header")
data = chunk[INIT_HEADER_LENGTH:] data = chunk[INIT_HEADER_LENGTH:]
return PacketHeader(ctrl_byte, cid, data_length), data return MessageHeader(ctrl_byte, cid, data_length), data
def read_next(transport: Transport, cid: int) -> bytes: def read_next(transport: Transport, cid: int) -> bytes:
chunk = transport.read_chunk() chunk = transport.read_chunk()
ctrl_byte, read_cid = struct.unpack( ctrl_byte, read_cid = struct.unpack(
PacketHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH] MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
) )
if ctrl_byte != CONTINUATION_PACKET: if ctrl_byte != CONTINUATION_PACKET:
raise RuntimeError("Continuation packet with incorrect control byte") raise RuntimeError("Continuation packet with incorrect control byte")