mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-07 14:00:57 +00:00
chore(trezorlib): clean trezolib THP, part 1
[no changelog]
This commit is contained in:
parent
386595cab2
commit
594ec26fa7
@ -8,6 +8,7 @@ import typing as t
|
||||
from binascii import hexlify
|
||||
from enum import IntEnum
|
||||
|
||||
import click
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ... import exceptions, messages
|
||||
@ -15,7 +16,7 @@ from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
from ..thp import checksum, curve25519, thp_io
|
||||
from ..thp.checksum import CHECKSUM_LENGTH
|
||||
from ..thp.packet_header import PacketHeader
|
||||
from ..thp.message_header import MessageHeader
|
||||
from . import channel_database, control_byte
|
||||
from .channel_data import ChannelData
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
@ -97,7 +98,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
sid, msg_type, msg_data = self.read_and_decrypt()
|
||||
if sid != session_id:
|
||||
raise Exception("Received messsage on different session.")
|
||||
raise Exception("Received messsage on a different session.")
|
||||
channel_database.save_channel(self)
|
||||
return self.mapping.decode(msg_type, msg_data)
|
||||
|
||||
@ -133,7 +134,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
channel_id_request_nonce = os.urandom(8)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
PacketHeader.get_channel_allocation_request_header(12),
|
||||
MessageHeader.get_channel_allocation_request_header(12),
|
||||
channel_id_request_nonce,
|
||||
)
|
||||
|
||||
@ -142,13 +143,14 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
if not self._is_valid_channel_allocation_response(
|
||||
header, payload, channel_id_request_nonce
|
||||
):
|
||||
print("TODO raise exception here, I guess")
|
||||
# TODO raise exception here, I guess
|
||||
raise Exception("Invalid channel allocation response.")
|
||||
|
||||
self.channel_id = int.from_bytes(payload[8:10], "big")
|
||||
self.device_properties = payload[10:]
|
||||
|
||||
# Send handshake init request
|
||||
ha_init_req_header = PacketHeader(0, self.channel_id, 36)
|
||||
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||
|
||||
@ -159,14 +161,17 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
# Read handshake init response
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
self._send_ack_0()
|
||||
|
||||
if not header.is_handshake_init_response():
|
||||
print("Received message is not a valid handshake init response message")
|
||||
click.echo(
|
||||
"Received message is not a valid handshake init response message",
|
||||
err=True,
|
||||
)
|
||||
|
||||
trezor_ephemeral_pubkey = payload[:32]
|
||||
encrypted_trezor_static_pubkey = payload[32:80]
|
||||
@ -193,7 +198,9 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
IV_1, encrypted_trezor_static_pubkey, h
|
||||
)
|
||||
except Exception as e:
|
||||
print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik
|
||||
click.echo(
|
||||
f"Exception of type{type(e)}", err=True
|
||||
) # TODO how to handle potential exceptions? Q for Matejcik
|
||||
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||
@ -225,7 +232,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
|
||||
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||
h = _sha256_of_two(h, encrypted_payload)
|
||||
ha_completion_req_header = PacketHeader(
|
||||
ha_completion_req_header = MessageHeader(
|
||||
0x12,
|
||||
self.channel_id,
|
||||
len(encrypted_host_static_pubkey)
|
||||
@ -241,12 +248,15 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
# Read handshake completion response, ignore payload as we do not care about the state
|
||||
header, _ = self._read_until_valid_crc_check()
|
||||
if not header.is_handshake_comp_response():
|
||||
print("Received message is not a valid handshake completion response")
|
||||
click.echo(
|
||||
"Received message is not a valid handshake completion response",
|
||||
err=True,
|
||||
)
|
||||
self._send_ack_1()
|
||||
|
||||
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||
@ -262,7 +272,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
print("Received message is not a valid ACK ")
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
# Read
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
@ -273,12 +283,12 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
|
||||
def _send_ack_0(self):
|
||||
LOG.debug("sending ack 0")
|
||||
header = PacketHeader(0x20, self.channel_id, 4)
|
||||
header = MessageHeader(0x20, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _send_ack_1(self):
|
||||
LOG.debug("sending ack 1")
|
||||
header = PacketHeader(0x28, self.channel_id, 4)
|
||||
header = MessageHeader(0x28, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _encrypt_and_write(
|
||||
@ -301,7 +311,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||
self.nonce_request += 1
|
||||
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||
header = PacketHeader(
|
||||
header = MessageHeader(
|
||||
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
||||
)
|
||||
|
||||
@ -314,9 +324,10 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
if control_byte.is_ack(header.ctrl_byte):
|
||||
return self.read_and_decrypt()
|
||||
if not header.is_encrypted_transport():
|
||||
print("Trying to decrypt not encrypted message!")
|
||||
print(
|
||||
hexlify(header.to_bytes_init()).decode(), hexlify(raw_payload).decode()
|
||||
click.echo(
|
||||
"Trying to decrypt not encrypted message!"
|
||||
+ hexlify(header.to_bytes_init() + raw_payload).decode(),
|
||||
err=True,
|
||||
)
|
||||
|
||||
if not control_byte.is_ack(header.ctrl_byte):
|
||||
@ -346,29 +357,36 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
|
||||
def _read_until_valid_crc_check(
|
||||
self,
|
||||
) -> t.Tuple[PacketHeader, bytes]:
|
||||
) -> t.Tuple[MessageHeader, bytes]:
|
||||
is_valid = False
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
while not is_valid:
|
||||
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||
if not is_valid:
|
||||
print(hexlify(header.to_bytes_init() + payload + chksum))
|
||||
LOG.debug("Received a message with invalid checksum")
|
||||
click.echo(
|
||||
"Received a message with an invalid checksum:"
|
||||
+ hexlify(header.to_bytes_init() + payload + chksum),
|
||||
err=True,
|
||||
)
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
|
||||
return header, payload
|
||||
|
||||
def _is_valid_channel_allocation_response(
|
||||
self, header: PacketHeader, payload: bytes, original_nonce: bytes
|
||||
self, header: MessageHeader, payload: bytes, original_nonce: bytes
|
||||
) -> bool:
|
||||
if not header.is_channel_allocation_response():
|
||||
print("Received message is not a channel allocation response")
|
||||
click.echo(
|
||||
"Received message is not a channel allocation response", err=True
|
||||
)
|
||||
return False
|
||||
if len(payload) < 10:
|
||||
print("Invalid channel allocation response payload")
|
||||
click.echo("Invalid channel allocation response payload", err=True)
|
||||
return False
|
||||
if payload[:8] != original_nonce:
|
||||
print("Invalid channel allocation response payload (nonce mismatch)")
|
||||
click.echo(
|
||||
"Invalid channel allocation response payload (nonce mismatch)", err=True
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
@ -23,7 +23,7 @@ TREZOR_STATE_PAIRED = b"\x01"
|
||||
BROADCAST_CHANNEL_ID = 0xFFFF
|
||||
|
||||
|
||||
class PacketHeader:
|
||||
class MessageHeader:
|
||||
format_str_init = ">BHH"
|
||||
format_str_cont = ">BH"
|
||||
|
@ -3,7 +3,7 @@ from typing import Tuple
|
||||
|
||||
from .. import Transport
|
||||
from ..thp import checksum
|
||||
from .packet_header import PacketHeader
|
||||
from .message_header import MessageHeader
|
||||
|
||||
INIT_HEADER_LENGTH = 5
|
||||
CONT_HEADER_LENGTH = 3
|
||||
@ -14,7 +14,7 @@ CONTINUATION_PACKET = 0x80
|
||||
|
||||
|
||||
def write_payload_to_wire_and_add_checksum(
|
||||
transport: Transport, header: PacketHeader, transport_payload: bytes
|
||||
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||
):
|
||||
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
|
||||
data = transport_payload + chksum
|
||||
@ -22,7 +22,7 @@ def write_payload_to_wire_and_add_checksum(
|
||||
|
||||
|
||||
def write_payload_to_wire(
|
||||
transport: Transport, header: PacketHeader, transport_payload: bytes
|
||||
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||
):
|
||||
transport.open()
|
||||
buffer = bytearray(transport_payload)
|
||||
@ -40,8 +40,18 @@ def write_payload_to_wire(
|
||||
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
|
||||
|
||||
|
||||
def read(transport: Transport) -> Tuple[PacketHeader, bytes, bytes]:
|
||||
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
|
||||
"""
|
||||
Reads from the given wire transport.
|
||||
|
||||
Returns `Tuple[MessageHeader, bytes, bytes]`:
|
||||
1. `header` (`MessageHeader`): Header of the message.
|
||||
2. `data` (`bytes`): Contents of the message (if any).
|
||||
3. `checksum` (`bytes`): crc32 checksum of the header + data.
|
||||
|
||||
"""
|
||||
buffer = bytearray()
|
||||
|
||||
# Read header with first part of message data
|
||||
header, first_chunk = read_first(transport)
|
||||
buffer.extend(first_chunk)
|
||||
@ -49,34 +59,31 @@ def read(transport: Transport) -> Tuple[PacketHeader, bytes, bytes]:
|
||||
# Read the rest of the message
|
||||
while len(buffer) < header.data_length:
|
||||
buffer.extend(read_next(transport, header.cid))
|
||||
# print("buffer read (data):", hexlify(buffer).decode())
|
||||
# print("buffer len (data):", datalen)
|
||||
# TODO check checksum?? or do not strip ?
|
||||
|
||||
data_len = header.data_length - checksum.CHECKSUM_LENGTH
|
||||
return (
|
||||
header,
|
||||
buffer[:data_len],
|
||||
buffer[data_len : data_len + checksum.CHECKSUM_LENGTH],
|
||||
)
|
||||
msg_data = buffer[:data_len]
|
||||
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
|
||||
|
||||
return (header, msg_data, chksum)
|
||||
|
||||
|
||||
def read_first(transport: Transport) -> Tuple[PacketHeader, bytes]:
|
||||
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
|
||||
chunk = transport.read_chunk()
|
||||
try:
|
||||
ctrl_byte, cid, data_length = struct.unpack(
|
||||
PacketHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
|
||||
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[INIT_HEADER_LENGTH:]
|
||||
return PacketHeader(ctrl_byte, cid, data_length), data
|
||||
return MessageHeader(ctrl_byte, cid, data_length), data
|
||||
|
||||
|
||||
def read_next(transport: Transport, cid: int) -> bytes:
|
||||
chunk = transport.read_chunk()
|
||||
ctrl_byte, read_cid = struct.unpack(
|
||||
PacketHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
|
||||
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
|
||||
)
|
||||
if ctrl_byte != CONTINUATION_PACKET:
|
||||
raise RuntimeError("Continuation packet with incorrect control byte")
|
||||
|
Loading…
Reference in New Issue
Block a user