1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-23 11:35:42 +00:00

fix(python): improve protocol_v2, remove its channel database

This commit is contained in:
M1nd3r 2025-03-13 14:11:32 +01:00
parent b5b0a354a9
commit 50c1299fb3
2 changed files with 15 additions and 55 deletions

View File

@ -12,11 +12,9 @@ from ... import exceptions, messages, protobuf
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp import checksum, thp_io
from ..thp.channel_data import ChannelData
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.message_header import MessageHeader
from . import control_byte
from .channel_database import ChannelDatabase, get_channel_db
from .protocol_and_channel import Channel
LOG = logging.getLogger(__name__)
@ -30,11 +28,6 @@ MT = t.TypeVar("MT", bound=protobuf.MessageType)
class ProtocolV2Channel(Channel):
channel_id: int
channel_database: ChannelDatabase
key_request: bytes
key_response: bytes
nonce_request: int
nonce_response: int
sync_bit_send: int
sync_bit_receive: int
handshake_hash: bytes
@ -46,52 +39,23 @@ class ProtocolV2Channel(Channel):
self,
transport: Transport,
mapping: ProtobufMapping,
channel_data: ChannelData | None = None,
) -> None:
self.channel_database: ChannelDatabase = get_channel_db()
super().__init__(transport, mapping)
if channel_data is not None:
self.channel_id = channel_data.channel_id
self.key_request = bytes.fromhex(channel_data.key_request)
self.key_response = bytes.fromhex(channel_data.key_response)
self.nonce_request = channel_data.nonce_request
self.nonce_response = channel_data.nonce_response
self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send
self.handshake_hash = bytes.fromhex(channel_data.handshake_hash)
self._has_valid_channel = True
def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2Channel:
if not self._has_valid_channel:
self._establish_new_channel(helper_debug)
return self
def get_channel_data(self) -> ChannelData:
return ChannelData(
protocol_version_major=2,
protocol_version_minor=2,
transport_path=self.transport.get_path(),
channel_id=self.channel_id,
key_request=self.noise.noise_protocol.cipher_state_encrypt.k,
key_response=self.noise.noise_protocol.cipher_state_decrypt.k,
nonce_request=self.nonce_request,
nonce_response=self.nonce_response,
sync_bit_receive=self.sync_bit_receive,
sync_bit_send=self.sync_bit_send,
handshake_hash=self.handshake_hash,
)
def read(self, session_id: int) -> t.Any:
sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id:
raise Exception("Received messsage on a different session.")
self.channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data)
def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data)
self.channel_database.save_channel(self)
def get_features(self) -> messages.Features:
if not self._has_valid_channel:
@ -166,13 +130,13 @@ class ProtocolV2Channel(Channel):
return (channel_id, device_properties)
def _init_noise(self, randomness_static: bytes) -> None:
self.noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256")
self.noise.set_as_initiator()
self.noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static)
self._noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256")
self._noise.set_as_initiator()
self._noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static)
prologue = bytes(self.device_properties)
self.noise.set_prologue(prologue)
self.noise.start_handshake()
self._noise.set_prologue(prologue)
self._noise.start_handshake()
def _do_handshake(
self,
@ -191,14 +155,10 @@ class ProtocolV2Channel(Channel):
)
self._read_ack()
self._read_handshake_completion_response()
self.key_request = self.noise.noise_protocol.cipher_state_encrypt.k
self.key_response = self.noise.noise_protocol.cipher_state_decrypt.k
self.nonce_request = 0
self.nonce_response = 1
def _send_handshake_init_request(self) -> None:
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
host_ephemeral_pubkey = self.noise.write_message()
host_ephemeral_pubkey = self._noise.write_message()
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, ha_init_req_header, host_ephemeral_pubkey
@ -219,7 +179,7 @@ class ProtocolV2Channel(Channel):
"Received message is not a valid handshake init response message",
err=True,
)
self.noise.read_message(payload)
self._noise.read_message(payload)
return payload
def _send_handshake_completion_request(
@ -237,7 +197,7 @@ class ProtocolV2Channel(Channel):
host_pairing_credential=credential,
)
)
message2 = self.noise.write_message(payload=msg_data)
message2 = self._noise.write_message(payload=msg_data)
ha_completion_req_header = MessageHeader(
0x12,
@ -249,7 +209,7 @@ class ProtocolV2Channel(Channel):
ha_completion_req_header,
message2, # encrypted_host_static_pubkey + encrypted_payload,
)
self.handshake_hash = self.noise.get_handshake_hash()
self.handshake_hash = self._noise.get_handshake_hash()
def _read_handshake_completion_response(self) -> None:
# Read handshake completion response, ignore payload as we do not care about the state
@ -259,9 +219,10 @@ class ProtocolV2Channel(Channel):
"Received message is not a valid handshake completion response",
err=True,
)
trezor_state = self.noise.decrypt(bytes(data))
trezor_state = self._noise.decrypt(bytes(data))
# TODO handle trezor_state
print("trezor state:", trezor_state)
assert trezor_state == b"\x00" or trezor_state == b"\x01"
self._send_ack_1()
def _do_pairing(self, helper_debug: DebugLink | None):
@ -305,7 +266,6 @@ class ProtocolV2Channel(Channel):
message_data: bytes,
ctrl_byte: int | None = None,
) -> None:
assert self.key_request is not None
if ctrl_byte is None:
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
@ -315,7 +275,7 @@ class ProtocolV2Channel(Channel):
msg_type = message_type.to_bytes(2, "big")
data = sid + msg_type + message_data
encrypted_message = self.noise.encrypt(data)
encrypted_message = self._noise.encrypt(data)
header = MessageHeader(
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
@ -354,7 +314,7 @@ class ProtocolV2Channel(Channel):
else:
self._send_ack_0()
message = self.noise.decrypt(bytes(raw_payload))
message = self._noise.decrypt(bytes(raw_payload))
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]

View File

@ -302,7 +302,7 @@ def test_credential_phase(client: Client) -> None:
# Delete channel from the device by sending badly encrypted message
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
protocol.noise.noise_protocol.cipher_state_encrypt.n = 250
protocol._noise.noise_protocol.cipher_state_encrypt.n = 250
protocol._send_message(ButtonAck())
with pytest.raises(Exception) as e:
@ -351,7 +351,7 @@ def test_credential_phase(client: Client) -> None:
# Delete channel from the device by sending badly encrypted message
# This is done to prevent channel replacement and trigerring of autoconnect false -> true
protocol.noise.noise_protocol.cipher_state_encrypt.n = 100
protocol._noise.noise_protocol.cipher_state_encrypt.n = 100
protocol._send_message(ButtonAck())
with pytest.raises(Exception) as e: