mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-30 17:21:21 +00:00
refactor(python): improve structure of protocol_v2 channel establishment
[no changelog]
This commit is contained in:
parent
6b3fa22c85
commit
e39de541bd
@ -132,40 +132,78 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
def _establish_new_channel(self) -> None:
|
def _establish_new_channel(self) -> None:
|
||||||
self.sync_bit_send = 0
|
self.sync_bit_send = 0
|
||||||
self.sync_bit_receive = 0
|
self.sync_bit_receive = 0
|
||||||
# Send channel allocation request
|
|
||||||
channel_id_request_nonce = os.urandom(8)
|
# Generate ephemeral keys
|
||||||
|
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||||
|
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||||
|
|
||||||
|
self._do_channel_allocation()
|
||||||
|
|
||||||
|
self._do_handshake(host_ephemeral_privkey, host_ephemeral_pubkey)
|
||||||
|
|
||||||
|
self._do_pairing()
|
||||||
|
|
||||||
|
def _do_channel_allocation(self) -> None:
|
||||||
|
channel_allocation_nonce = os.urandom(8)
|
||||||
|
self._send_channel_allocation_request(channel_allocation_nonce)
|
||||||
|
cid, dp = self._read_channel_allocation_response(channel_allocation_nonce)
|
||||||
|
self.channel_id = cid
|
||||||
|
self.device_properties = dp
|
||||||
|
|
||||||
|
def _send_channel_allocation_request(self, nonce: bytes):
|
||||||
thp_io.write_payload_to_wire_and_add_checksum(
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
self.transport,
|
self.transport,
|
||||||
MessageHeader.get_channel_allocation_request_header(12),
|
MessageHeader.get_channel_allocation_request_header(12),
|
||||||
channel_id_request_nonce,
|
nonce,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Read channel allocation response
|
def _read_channel_allocation_response(
|
||||||
|
self, expected_nonce: bytes
|
||||||
|
) -> tuple[int, bytes]:
|
||||||
header, payload = self._read_until_valid_crc_check()
|
header, payload = self._read_until_valid_crc_check()
|
||||||
if not self._is_valid_channel_allocation_response(
|
if not self._is_valid_channel_allocation_response(
|
||||||
header, payload, channel_id_request_nonce
|
header, payload, expected_nonce
|
||||||
):
|
):
|
||||||
# TODO raise exception here, I guess
|
|
||||||
raise Exception("Invalid channel allocation response.")
|
raise Exception("Invalid channel allocation response.")
|
||||||
|
|
||||||
self.channel_id = int.from_bytes(payload[8:10], "big")
|
channel_id = int.from_bytes(payload[8:10], "big")
|
||||||
self.device_properties = payload[10:]
|
device_properties = payload[10:]
|
||||||
|
return (channel_id, device_properties)
|
||||||
|
|
||||||
# Send handshake init request
|
def _do_handshake(
|
||||||
|
self, host_ephemeral_privkey: bytes, host_ephemeral_pubkey: bytes
|
||||||
|
):
|
||||||
|
self._send_handshake_init_request(host_ephemeral_pubkey)
|
||||||
|
self._read_ack()
|
||||||
|
init_response = self._read_handshake_init_response()
|
||||||
|
# assert len(init_response) == 96
|
||||||
|
trezor_ephemeral_pubkey = init_response[:32]
|
||||||
|
encrypted_trezor_static_pubkey = init_response[32:80]
|
||||||
|
noise_tag = init_response[80:96]
|
||||||
|
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
||||||
|
|
||||||
|
# TODO check noise_tag is valid
|
||||||
|
|
||||||
|
ck = self._send_handshake_completion_request(
|
||||||
|
host_ephemeral_pubkey,
|
||||||
|
host_ephemeral_privkey,
|
||||||
|
trezor_ephemeral_pubkey,
|
||||||
|
encrypted_trezor_static_pubkey,
|
||||||
|
)
|
||||||
|
self._read_ack()
|
||||||
|
self._read_handshake_completion_response()
|
||||||
|
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||||
|
self.nonce_request = 0
|
||||||
|
self.nonce_response = 1
|
||||||
|
|
||||||
|
def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None:
|
||||||
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||||
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
|
||||||
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
|
||||||
|
|
||||||
thp_io.write_payload_to_wire_and_add_checksum(
|
thp_io.write_payload_to_wire_and_add_checksum(
|
||||||
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||||
)
|
)
|
||||||
|
|
||||||
# Read ACK
|
def _read_handshake_init_response(self) -> bytes:
|
||||||
header, payload = self._read_until_valid_crc_check()
|
|
||||||
if not header.is_ack() or len(payload) > 0:
|
|
||||||
click.echo("Received message is not a valid ACK", err=True)
|
|
||||||
|
|
||||||
# Read handshake init response
|
|
||||||
header, payload = self._read_until_valid_crc_check()
|
header, payload = self._read_until_valid_crc_check()
|
||||||
self._send_ack_0()
|
self._send_ack_0()
|
||||||
|
|
||||||
@ -180,15 +218,16 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
"Received message is not a valid handshake init response message",
|
"Received message is not a valid handshake init response message",
|
||||||
err=True,
|
err=True,
|
||||||
)
|
)
|
||||||
|
return payload
|
||||||
|
|
||||||
trezor_ephemeral_pubkey = payload[:32]
|
def _send_handshake_completion_request(
|
||||||
encrypted_trezor_static_pubkey = payload[32:80]
|
self,
|
||||||
noise_tag = payload[80:96]
|
host_ephemeral_pubkey: bytes,
|
||||||
|
host_ephemeral_privkey: bytes,
|
||||||
# TODO check noise tag
|
trezor_ephemeral_pubkey: bytes,
|
||||||
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
encrypted_trezor_static_pubkey: bytes,
|
||||||
|
credential: bytes | None = None,
|
||||||
# Prepare and send handshake completion request
|
) -> bytes:
|
||||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||||
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||||
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||||
@ -232,7 +271,8 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
messages.ThpHandshakeCompletionReqNoisePayload(
|
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||||
pairing_methods=[
|
pairing_methods=[
|
||||||
messages.ThpPairingMethod.NoMethod,
|
messages.ThpPairingMethod.NoMethod,
|
||||||
]
|
],
|
||||||
|
host_pairing_credential=credential,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -252,12 +292,9 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
ha_completion_req_header,
|
ha_completion_req_header,
|
||||||
encrypted_host_static_pubkey + encrypted_payload,
|
encrypted_host_static_pubkey + encrypted_payload,
|
||||||
)
|
)
|
||||||
|
return ck
|
||||||
|
|
||||||
# Read ACK
|
def _read_handshake_completion_response(self) -> None:
|
||||||
header, payload = self._read_until_valid_crc_check()
|
|
||||||
if not header.is_ack() or len(payload) > 0:
|
|
||||||
click.echo("Received message is not a valid ACK", err=True)
|
|
||||||
|
|
||||||
# Read handshake completion response, ignore payload as we do not care about the state
|
# Read handshake completion response, ignore payload as we do not care about the state
|
||||||
header, _ = self._read_until_valid_crc_check()
|
header, _ = self._read_until_valid_crc_check()
|
||||||
if not header.is_handshake_comp_response():
|
if not header.is_handshake_comp_response():
|
||||||
@ -267,10 +304,7 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
)
|
)
|
||||||
self._send_ack_1()
|
self._send_ack_1()
|
||||||
|
|
||||||
self.key_request, self.key_response = _hkdf(ck, b"")
|
def _do_pairing(self):
|
||||||
self.nonce_request = 0
|
|
||||||
self.nonce_response = 1
|
|
||||||
|
|
||||||
# Send StartPairingReqest message
|
# Send StartPairingReqest message
|
||||||
message = messages.ThpStartPairingRequest()
|
message = messages.ThpStartPairingRequest()
|
||||||
message_type, message_data = self.mapping.encode(message)
|
message_type, message_data = self.mapping.encode(message)
|
||||||
@ -278,17 +312,20 @@ class ProtocolV2(ProtocolAndChannel):
|
|||||||
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data)
|
||||||
|
|
||||||
# Read ACK
|
# Read ACK
|
||||||
header, payload = self._read_until_valid_crc_check()
|
self._read_ack()
|
||||||
if not header.is_ack() or len(payload) > 0:
|
|
||||||
click.echo("Received message is not a valid ACK", err=True)
|
|
||||||
|
|
||||||
# Read
|
# Read ThpEndResponse
|
||||||
_, msg_type, msg_data = self.read_and_decrypt()
|
_, msg_type, msg_data = self.read_and_decrypt()
|
||||||
maaa = self.mapping.decode(msg_type, msg_data)
|
maaa = self.mapping.decode(msg_type, msg_data)
|
||||||
|
|
||||||
assert isinstance(maaa, messages.ThpEndResponse)
|
assert isinstance(maaa, messages.ThpEndResponse)
|
||||||
self._has_valid_channel = True
|
self._has_valid_channel = True
|
||||||
|
|
||||||
|
def _read_ack(self):
|
||||||
|
header, payload = self._read_until_valid_crc_check()
|
||||||
|
if not header.is_ack() or len(payload) > 0:
|
||||||
|
click.echo("Received message is not a valid ACK", err=True)
|
||||||
|
|
||||||
def _send_ack_0(self):
|
def _send_ack_0(self):
|
||||||
LOG.debug("sending ack 0")
|
LOG.debug("sending ack 0")
|
||||||
header = MessageHeader(0x20, self.channel_id, 4)
|
header = MessageHeader(0x20, self.channel_id, 4)
|
||||||
|
Loading…
Reference in New Issue
Block a user