mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-27 07:40:59 +00:00
refactor(python): improve structure of protocol_v2 channel establishment
[no changelog]
This commit is contained in:
parent
6b3fa22c85
commit
e39de541bd
@ -132,40 +132,78 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
def _establish_new_channel(self) -> None:
|
||||
self.sync_bit_send = 0
|
||||
self.sync_bit_receive = 0
|
||||
# Send channel allocation request
|
||||
channel_id_request_nonce = os.urandom(8)
|
||||
|
||||
# Generate ephemeral keys
|
||||
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||
|
||||
self._do_channel_allocation()
|
||||
|
||||
self._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
|
||||
|
||||
self._do_pairing()
|
||||
|
||||
def _do_channel_allocation(self) -> None:
|
||||
channel_allocation_nonce = os.urandom(8)
|
||||
self._send_channel_allocation_request(channel_allocation_nonce)
|
||||
cid, dp = self._read_channel_allocation_response(channel_allocation_nonce)
|
||||
self.channel_id = cid
|
||||
self.device_properties = dp
|
||||
|
||||
def _send_channel_allocation_request(self, nonce: bytes):
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
MessageHeader.get_channel_allocation_request_header(12),
|
||||
channel_id_request_nonce,
|
||||
nonce,
|
||||
)
|
||||
|
||||
# Read channel allocation response
|
||||
def _read_channel_allocation_response(
|
||||
self, expected_nonce: bytes
|
||||
) -> tuple[int, bytes]:
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not self._is_valid_channel_allocation_response(
|
||||
header, payload, channel_id_request_nonce
|
||||
header, payload, expected_nonce
|
||||
):
|
||||
# TODO raise exception here, I guess
|
||||
raise Exception("Invalid channel allocation response.")
|
||||
|
||||
self.channel_id = int.from_bytes(payload[8:10], "big")
|
||||
self.device_properties = payload[10:]
|
||||
channel_id = int.from_bytes(payload[8:10], "big")
|
||||
device_properties = payload[10:]
|
||||
return (channel_id, device_properties)
|
||||
|
||||
# Send handshake init request
|
||||
def _do_handshake(
|
||||
self, host_ephemeral_privkey: bytes, host_ephemeral_pubkey: bytes
|
||||
):
|
||||
self._send_handshake_init_request(host_ephemeral_pubkey)
|
||||
self._read_ack()
|
||||
init_response = self._read_handshake_init_response()
|
||||
# assert len(init_response) == 96
|
||||
trezor_ephemeral_pubkey = init_response[:32]
|
||||
encrypted_trezor_static_pubkey = init_response[32:80]
|
||||
noise_tag = init_response[80:96]
|
||||
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
||||
|
||||
# TODO check noise_tag is valid
|
||||
|
||||
ck = self._send_handshake_completion_request(
|
||||
host_ephemeral_pubkey,
|
||||
host_ephemeral_privkey,
|
||||
trezor_ephemeral_pubkey,
|
||||
encrypted_trezor_static_pubkey,
|
||||
)
|
||||
self._read_ack()
|
||||
self._read_handshake_completion_response()
|
||||
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||
self.nonce_request = 0
|
||||
self.nonce_response = 1
|
||||
|
||||
def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None:
|
||||
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||
)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
# Read handshake init response
|
||||
def _read_handshake_init_response(self) -> bytes:
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
self._send_ack_0()
|
||||
|
||||
@ -180,15 +218,16 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
"Received message is not a valid handshake init response message",
|
||||
err=True,
|
||||
)
|
||||
return payload
|
||||
|
||||
trezor_ephemeral_pubkey = payload[:32]
|
||||
encrypted_trezor_static_pubkey = payload[32:80]
|
||||
noise_tag = payload[80:96]
|
||||
|
||||
# TODO check noise tag
|
||||
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
||||
|
||||
# Prepare and send handshake completion request
|
||||
def _send_handshake_completion_request(
|
||||
self,
|
||||
host_ephemeral_pubkey: bytes,
|
||||
host_ephemeral_privkey: bytes,
|
||||
trezor_ephemeral_pubkey: bytes,
|
||||
encrypted_trezor_static_pubkey: bytes,
|
||||
credential: bytes | None = None,
|
||||
) -> bytes:
|
||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
@ -232,7 +271,8 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||
pairing_methods=[
|
||||
messages.ThpPairingMethod.NoMethod,
|
||||
]
|
||||
],
|
||||
host_pairing_credential=credential,
|
||||
)
|
||||
)
|
||||
|
||||
@ -252,12 +292,9 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
ha_completion_req_header,
|
||||
encrypted_host_static_pubkey + encrypted_payload,
|
||||
)
|
||||
return ck
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
def _read_handshake_completion_response(self) -> None:
|
||||
# Read handshake completion response, ignore payload as we do not care about the state
|
||||
header, _ = self._read_until_valid_crc_check()
|
||||
if not header.is_handshake_comp_response():
|
||||
@ -267,10 +304,7 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
)
|
||||
self._send_ack_1()
|
||||
|
||||
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||
self.nonce_request = 0
|
||||
self.nonce_response = 1
|
||||
|
||||
def _do_pairing(self):
|
||||
# Send StartPairingReqest message
|
||||
message = messages.ThpStartPairingRequest()
|
||||
message_type, message_data = self.mapping.encode(message)
|
||||
@ -278,17 +312,20 @@ class ProtocolV2(ProtocolAndChannel):
|
||||
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||
|
||||
# Read ACK
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
self._read_ack()
|
||||
|
||||
# Read
|
||||
# Read ThpEndResponse
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
maaa = self.mapping.decode(msg_type, msg_data)
|
||||
|
||||
assert isinstance(maaa, messages.ThpEndResponse)
|
||||
self._has_valid_channel = True
|
||||
|
||||
def _read_ack(self):
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
def _send_ack_0(self):
|
||||
LOG.debug("sending ack 0")
|
||||
header = MessageHeader(0x20, self.channel_id, 4)
|
||||
|
Loading…
Reference in New Issue
Block a user