1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-12 06:06:07 +00:00
This commit is contained in:
M1nd3r 2025-03-03 10:02:02 +01:00
parent c68684062d
commit 03d5751f3f
3 changed files with 87 additions and 158 deletions

View File

@ -110,8 +110,13 @@ class Handshake:
encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey)
if __debug__:
log.debug(
__name__, "th1 - enc (key: %s, nonce: %d)", get_bytes_as_str(self.k), 0
__name__,
"th1 - enc (key: %s, nonce: %d, handshake_hash %s)",
get_bytes_as_str(self.k),
0,
get_bytes_as_str(self.h),
)
aes_ctx.auth(self.h)
tag_to_encrypted_key = aes_ctx.finish()
encrypted_trezor_static_pubkey = (

View File

@ -1,19 +1,17 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import os
import typing as t
from binascii import hexlify
import click
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from noise.connection import Keypair, NoiseConnection
from ... import exceptions, messages, protobuf
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp import checksum, curve25519, thp_io
from ..thp import checksum, thp_io
from ..thp.channel_data import ChannelData
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.message_header import MessageHeader
@ -30,27 +28,6 @@ if t.TYPE_CHECKING:
MT = t.TypeVar("MT", bound=protobuf.MessageType)
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
hash = hashlib.sha256(val_1)
hash.update(val_2)
return hash.digest()
def _hkdf(chaining_key: bytes, input: bytes):
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _get_iv_from_nonce(nonce: int) -> bytes:
if not nonce <= 0xFFFFFFFFFFFFFFFF:
raise ValueError("Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")
class ProtocolV2Channel(Channel):
channel_id: int
channel_database: ChannelDatabase
@ -95,8 +72,8 @@ class ProtocolV2Channel(Channel):
protocol_version_minor=2,
transport_path=self.transport.get_path(),
channel_id=self.channel_id,
key_request=self.key_request,
key_response=self.key_response,
key_request=self.noise.noise_protocol.cipher_state_encrypt.k,
key_response=self.noise.noise_protocol.cipher_state_decrypt.k,
nonce_request=self.nonce_request,
nonce_response=self.nonce_response,
sync_bit_receive=self.sync_bit_receive,
@ -188,39 +165,40 @@ class ProtocolV2Channel(Channel):
device_properties = payload[10:]
return (channel_id, device_properties)
def _init_noise(self, randomness_static: bytes) -> None:
self.noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256")
self.noise.set_as_initiator()
self.noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static)
prologue = bytes(self.device_properties)
self.noise.set_prologue(prologue)
self.noise.start_handshake()
def _do_handshake(
self, credential: bytes | None = None, host_static_privkey: bytes | None = None
self,
credential: bytes | None = None,
host_static_randomness: bytes | None = None,
):
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
self._send_handshake_init_request(host_ephemeral_pubkey)
randomness_static = host_static_randomness or os.urandom(32)
self._init_noise(randomness_static)
self._send_handshake_init_request()
self._read_ack()
init_response = self._read_handshake_init_response()
trezor_ephemeral_pubkey = init_response[:32]
encrypted_trezor_static_pubkey = init_response[32:80]
noise_tag = init_response[80:96]
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# TODO check noise_tag is valid
ck = self._send_handshake_completion_request(
host_ephemeral_pubkey,
host_ephemeral_privkey,
trezor_ephemeral_pubkey,
encrypted_trezor_static_pubkey,
self._read_handshake_init_response()
self._send_handshake_completion_request(
credential,
host_static_privkey,
)
self._read_ack()
self._read_handshake_completion_response()
self.key_request, self.key_response = _hkdf(ck, b"")
self.key_request = self.noise.noise_protocol.cipher_state_encrypt.k
self.key_response = self.noise.noise_protocol.cipher_state_decrypt.k
self.nonce_request = 0
self.nonce_response = 1
def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None:
def _send_handshake_init_request(self) -> None:
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
host_ephemeral_pubkey = self.noise.write_message()
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, ha_init_req_header, host_ephemeral_pubkey
@ -241,90 +219,49 @@ class ProtocolV2Channel(Channel):
"Received message is not a valid handshake init response message",
err=True,
)
self.noise.read_message(payload)
return payload
def _send_handshake_completion_request(
self,
host_ephemeral_pubkey: bytes,
host_ephemeral_privkey: bytes,
trezor_ephemeral_pubkey: bytes,
encrypted_trezor_static_pubkey: bytes,
credential: bytes | None = None,
host_static_privkey: bytes | None = None,
) -> bytes:
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
h = _sha256_of_two(h, host_ephemeral_pubkey)
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
ck, k = _hkdf(
PROTOCOL_NAME,
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
)
) -> None:
# TODO implement key recognition
# print(
# "TREZOR's static pubkey:\n",
# self.noise.noise_protocol.handshake_state.rs.public.public_bytes_raw(),
# )
aes_ctx = AESGCM(k)
try:
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
except Exception as e:
click.echo(
f"Exception of type{type(e)}", err=True
) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
)
aes_ctx = AESGCM(k)
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials
if host_static_privkey is None:
host_static_privkey = curve25519.get_private_key(os.urandom(32))
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = self.mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(
host_pairing_credential=credential,
)
)
message2 = self.noise.write_message(payload=msg_data)
aes_ctx = AESGCM(k)
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload[:-16])
ha_completion_req_header = MessageHeader(
0x12,
self.channel_id,
len(encrypted_host_static_pubkey)
+ len(encrypted_payload)
+ CHECKSUM_LENGTH,
len(message2) + CHECKSUM_LENGTH,
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
message2, # encrypted_host_static_pubkey + encrypted_payload,
)
self.handshake_hash = h
return ck
self.handshake_hash = self.noise.get_handshake_hash()
def _read_handshake_completion_response(self) -> None:
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
header, data = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
click.echo(
"Received message is not a valid handshake completion response",
err=True,
)
trezor_state = self.noise.decrypt(bytes(data))
# TODO handle trezor_state
print("trezor state:", trezor_state)
self._send_ack_1()
def _do_pairing(self, helper_debug: DebugLink | None):
@ -369,7 +306,6 @@ class ProtocolV2Channel(Channel):
ctrl_byte: int | None = None,
) -> None:
assert self.key_request is not None
aes_ctx = AESGCM(self.key_request)
if ctrl_byte is None:
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
@ -378,9 +314,9 @@ class ProtocolV2Channel(Channel):
sid = session_id.to_bytes(1, "big")
msg_type = message_type.to_bytes(2, "big")
data = sid + msg_type + message_data
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
encrypted_message = self.noise.encrypt(data)
header = MessageHeader(
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
)
@ -417,11 +353,8 @@ class ProtocolV2Channel(Channel):
self._send_ack_1()
else:
self._send_ack_0()
aes_ctx = AESGCM(self.key_response)
nonce = _get_iv_from_nonce(self.nonce_response)
self.nonce_response += 1
message = aes_ctx.decrypt(nonce, raw_payload, b"")
message = self.noise.decrypt(bytes(raw_payload))
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]

View File

@ -34,7 +34,6 @@ from trezorlib.messages import (
)
from trezorlib.transport.thp import curve25519
from trezorlib.transport.thp.cpace import Cpace
from trezorlib.transport.thp.protocol_v2 import _hkdf
if t.TYPE_CHECKING:
P = tx.ParamSpec("P")
@ -53,18 +52,18 @@ def _prepare_protocol(client: Client) -> ProtocolV2Channel:
def _prepare_protocol_for_pairing(
client: Client, host_static_privkey: bytes | None = None
client: Client, host_static_randomness: bytes | None = None
) -> ProtocolV2Channel:
protocol = _prepare_protocol(client)
protocol._do_handshake(host_static_privkey=host_static_privkey)
protocol._do_handshake(host_static_randomness=host_static_randomness)
return protocol
def _get_encrypted_transport_protocol(
client: Client, host_static_privkey: bytes | None = None
client: Client, host_static_randomness: bytes | None = None
) -> ProtocolV2Channel:
protocol = _prepare_protocol_for_pairing(
client, host_static_privkey=host_static_privkey
client, host_static_randomness=host_static_randomness
)
protocol._do_pairing(client.debug)
return protocol
@ -105,39 +104,24 @@ def test_allocate_channel(client: Client) -> None:
def test_handshake(client: Client) -> None:
protocol = _prepare_protocol(client)
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
randomness_static = os.urandom(32)
protocol._do_channel_allocation()
protocol._send_handshake_init_request(host_ephemeral_pubkey)
protocol._read_ack()
init_response = protocol._read_handshake_init_response()
trezor_ephemeral_pubkey = init_response[:32]
encrypted_trezor_static_pubkey = init_response[32:80]
noise_tag = init_response[80:96]
# TODO check noise_tag is valid
ck = protocol._send_handshake_completion_request(
host_ephemeral_pubkey,
host_ephemeral_privkey,
trezor_ephemeral_pubkey,
encrypted_trezor_static_pubkey,
protocol._init_noise(
randomness_static=randomness_static,
)
protocol._send_handshake_init_request()
protocol._read_ack()
protocol._read_handshake_init_response()
protocol._send_handshake_completion_request()
protocol._read_ack()
protocol._read_handshake_completion_response()
protocol.key_request, protocol.key_response = _hkdf(ck, b"")
protocol.nonce_request = 0
protocol.nonce_response = 1
# TODO - without pairing, the client is damaged and results in fail of the following test
# so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error
protocol._do_pairing(client.debug)
# TODO the following is just to make style checker happy
assert noise_tag is not None
def test_pairing_qr_code(client: Client) -> None:
protocol = _prepare_protocol_for_pairing(client)
@ -293,7 +277,8 @@ def test_credential_phase(client: Client) -> None:
_nfc_pairing(client, protocol)
# Request credential with confirmation after pairing
host_static_privkey = curve25519.get_private_key(os.urandom(32))
randomness_static = os.urandom(32)
host_static_privkey = curve25519.get_private_key(randomness_static)
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
protocol._send_message(
ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False)
@ -308,7 +293,7 @@ def test_credential_phase(client: Client) -> None:
# Connect using credential with confirmation
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake(credential, host_static_privkey)
protocol._do_handshake(credential, randomness_static)
protocol._send_message(ThpEndRequest())
button_req = protocol._read_message(ButtonRequest)
assert button_req.name == "connection_request"
@ -318,7 +303,8 @@ def test_credential_phase(client: Client) -> None:
# Delete channel from the device by sending badly encrypted message
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
protocol.nonce_request = 250
protocol.noise.noise_protocol.cipher_state_encrypt.n = 250
protocol._send_message(ButtonAck())
with pytest.raises(Exception) as e:
protocol.read(1)
@ -327,7 +313,7 @@ def test_credential_phase(client: Client) -> None:
# Connect using credential with confirmation and ask for autoconnect credential.
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake(credential, host_static_privkey)
protocol._do_handshake(credential, randomness_static)
protocol._send_message(
ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=True)
)
@ -345,7 +331,7 @@ def test_credential_phase(client: Client) -> None:
# Connect using credential with confirmation
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake(credential, host_static_privkey)
protocol._do_handshake(credential, randomness_static)
# Confirmation dialog is not shown as channel in ENCRYPTED TRANSPORT state with the same
# host static public key is still available in Trezor's cache. (Channel replacement is triggered.)
protocol._send_message(ThpEndRequest())
@ -354,13 +340,14 @@ def test_credential_phase(client: Client) -> None:
# Connect using autoconnect credential
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake(credential_auto, host_static_privkey)
protocol._do_handshake(credential_auto, randomness_static)
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
# Delete channel from the device by sending badly encrypted message
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
protocol.nonce_request = 250
protocol.noise.noise_protocol.cipher_state_encrypt.n = 100
protocol._send_message(ButtonAck())
with pytest.raises(Exception) as e:
protocol.read(1)
@ -369,7 +356,7 @@ def test_credential_phase(client: Client) -> None:
# Connect using autoconnect credential - should work the same as above
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake(credential_auto, host_static_privkey)
protocol._do_handshake(credential_auto, randomness_static)
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
@ -378,23 +365,25 @@ def test_credential_phase(client: Client) -> None:
def test_channel_replacement(client: Client) -> None:
assert client.features.passphrase_protection is True
host_static_privkey = curve25519.get_private_key(os.urandom(32))
host_static_privkey_2 = curve25519.get_private_key(os.urandom(32))
host_static_randomness = os.urandom(32)
host_static_randomness_2 = os.urandom(32)
host_static_privkey = curve25519.get_private_key(host_static_randomness)
host_static_privkey_2 = curve25519.get_private_key(host_static_randomness_2)
assert host_static_privkey != host_static_privkey_2
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey)
client.protocol = _get_encrypted_transport_protocol(client, host_static_randomness)
session = client.get_session(passphrase="TREZOR", session_id=20)
session = client.get_session(passphrase="TREZOR", session_id=b"\x10")
address = get_test_address(session)
session_2 = client.get_session(passphrase="ROZERT", session_id=30)
session_2 = client.get_session(passphrase="ROZERT", session_id=b"\x20")
address_2 = get_test_address(session_2)
assert address != address_2
# create new channel using the same host_static_privkey
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey)
session_3 = client.get_session(passphrase="OKIDOKI", session_id=40)
client.protocol = _get_encrypted_transport_protocol(client, host_static_randomness)
session_3 = client.get_session(passphrase="OKIDOKI", session_id=b"\x30")
address_3 = get_test_address(session_3)
assert address_3 != address_2
@ -405,7 +394,9 @@ def test_channel_replacement(client: Client) -> None:
assert address_3 == new_address_3
# create new channel using different host_static_privkey
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey_2)
client.protocol = _get_encrypted_transport_protocol(
client, host_static_randomness_2
)
with pytest.raises(exceptions.TrezorFailure) as e_1:
_ = get_test_address(session)
assert str(e_1.value.message) == "Invalid session"
@ -414,6 +405,6 @@ def test_channel_replacement(client: Client) -> None:
_ = get_test_address(session_3)
assert str(e_2.value.message) == "Invalid session"
session_4 = client.get_session(passphrase="TREZOR", session_id=80)
session_4 = client.get_session(passphrase="TREZOR", session_id=b"\x40")
super_new_address = get_test_address(session_4)
assert address == super_new_address