diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index 8b3a2a3828..160054b30e 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -12,11 +12,9 @@ from ... import exceptions, messages, protobuf from ...mapping import ProtobufMapping from .. import Transport from ..thp import checksum, thp_io -from ..thp.channel_data import ChannelData from ..thp.checksum import CHECKSUM_LENGTH from ..thp.message_header import MessageHeader from . import control_byte -from .channel_database import ChannelDatabase, get_channel_db from .protocol_and_channel import Channel LOG = logging.getLogger(__name__) @@ -30,11 +28,6 @@ MT = t.TypeVar("MT", bound=protobuf.MessageType) class ProtocolV2Channel(Channel): channel_id: int - channel_database: ChannelDatabase - key_request: bytes - key_response: bytes - nonce_request: int - nonce_response: int sync_bit_send: int sync_bit_receive: int handshake_hash: bytes @@ -46,52 +39,23 @@ class ProtocolV2Channel(Channel): self, transport: Transport, mapping: ProtobufMapping, - channel_data: ChannelData | None = None, ) -> None: - self.channel_database: ChannelDatabase = get_channel_db() super().__init__(transport, mapping) - if channel_data is not None: - self.channel_id = channel_data.channel_id - self.key_request = bytes.fromhex(channel_data.key_request) - self.key_response = bytes.fromhex(channel_data.key_response) - self.nonce_request = channel_data.nonce_request - self.nonce_response = channel_data.nonce_response - self.sync_bit_receive = channel_data.sync_bit_receive - self.sync_bit_send = channel_data.sync_bit_send - self.handshake_hash = bytes.fromhex(channel_data.handshake_hash) - self._has_valid_channel = True def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2Channel: if not self._has_valid_channel: self._establish_new_channel(helper_debug) return self - def get_channel_data(self) -> ChannelData: - return ChannelData( - protocol_version_major=2, - protocol_version_minor=2, - transport_path=self.transport.get_path(), - channel_id=self.channel_id, - key_request=self.noise.noise_protocol.cipher_state_encrypt.k, - key_response=self.noise.noise_protocol.cipher_state_decrypt.k, - nonce_request=self.nonce_request, - nonce_response=self.nonce_response, - sync_bit_receive=self.sync_bit_receive, - sync_bit_send=self.sync_bit_send, - handshake_hash=self.handshake_hash, - ) - def read(self, session_id: int) -> t.Any: sid, msg_type, msg_data = self.read_and_decrypt() if sid != session_id: raise Exception("Received messsage on a different session.") - self.channel_database.save_channel(self) return self.mapping.decode(msg_type, msg_data) def write(self, session_id: int, msg: t.Any) -> None: msg_type, msg_data = self.mapping.encode(msg) self._encrypt_and_write(session_id, msg_type, msg_data) - self.channel_database.save_channel(self) def get_features(self) -> messages.Features: if not self._has_valid_channel: @@ -166,13 +130,13 @@ class ProtocolV2Channel(Channel): return (channel_id, device_properties) def _init_noise(self, randomness_static: bytes) -> None: - self.noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256") - self.noise.set_as_initiator() - self.noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static) + self._noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256") + self._noise.set_as_initiator() + self._noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static) prologue = bytes(self.device_properties) - self.noise.set_prologue(prologue) - self.noise.start_handshake() + self._noise.set_prologue(prologue) + self._noise.start_handshake() def _do_handshake( self, @@ -191,14 +155,10 @@ class ProtocolV2Channel(Channel): ) self._read_ack() self._read_handshake_completion_response() - self.key_request = self.noise.noise_protocol.cipher_state_encrypt.k - self.key_response = self.noise.noise_protocol.cipher_state_decrypt.k - self.nonce_request = 0 - self.nonce_response = 1 def _send_handshake_init_request(self) -> None: ha_init_req_header = MessageHeader(0, self.channel_id, 36) - host_ephemeral_pubkey = self.noise.write_message() + host_ephemeral_pubkey = self._noise.write_message() thp_io.write_payload_to_wire_and_add_checksum( self.transport, ha_init_req_header, host_ephemeral_pubkey @@ -219,7 +179,7 @@ class ProtocolV2Channel(Channel): "Received message is not a valid handshake init response message", err=True, ) - self.noise.read_message(payload) + self._noise.read_message(payload) return payload def _send_handshake_completion_request( @@ -237,7 +197,7 @@ class ProtocolV2Channel(Channel): host_pairing_credential=credential, ) ) - message2 = self.noise.write_message(payload=msg_data) + message2 = self._noise.write_message(payload=msg_data) ha_completion_req_header = MessageHeader( 0x12, @@ -249,7 +209,7 @@ class ProtocolV2Channel(Channel): ha_completion_req_header, message2, # encrypted_host_static_pubkey + encrypted_payload, ) - self.handshake_hash = self.noise.get_handshake_hash() + self.handshake_hash = self._noise.get_handshake_hash() def _read_handshake_completion_response(self) -> None: # Read handshake completion response, ignore payload as we do not care about the state @@ -259,9 +219,10 @@ class ProtocolV2Channel(Channel): "Received message is not a valid handshake completion response", err=True, ) - trezor_state = self.noise.decrypt(bytes(data)) + trezor_state = self._noise.decrypt(bytes(data)) # TODO handle trezor_state print("trezor state:", trezor_state) + assert trezor_state == b"\x00" or trezor_state == b"\x01" self._send_ack_1() def _do_pairing(self, helper_debug: DebugLink | None): @@ -305,7 +266,6 @@ class ProtocolV2Channel(Channel): message_data: bytes, ctrl_byte: int | None = None, ) -> None: - assert self.key_request is not None if ctrl_byte is None: ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send) @@ -315,7 +275,7 @@ class ProtocolV2Channel(Channel): msg_type = message_type.to_bytes(2, "big") data = sid + msg_type + message_data - encrypted_message = self.noise.encrypt(data) + encrypted_message = self._noise.encrypt(data) header = MessageHeader( ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH @@ -354,7 +314,7 @@ class ProtocolV2Channel(Channel): else: self._send_ack_0() - message = self.noise.decrypt(bytes(raw_payload)) + message = self._noise.decrypt(bytes(raw_payload)) session_id = message[0] message_type = message[1:3] message_data = message[3:] diff --git a/tests/device_tests/thp/test_thp.py b/tests/device_tests/thp/test_thp.py index d2ae587e3a..7fbd899689 100644 --- a/tests/device_tests/thp/test_thp.py +++ b/tests/device_tests/thp/test_thp.py @@ -302,7 +302,7 @@ def test_credential_phase(client: Client) -> None: # Delete channel from the device by sending badly encrypted message # This is done to prevent channel replacement and trigerring of autoconnect false -> true - protocol.noise.noise_protocol.cipher_state_encrypt.n = 250 + protocol._noise.noise_protocol.cipher_state_encrypt.n = 250 protocol._send_message(ButtonAck()) with pytest.raises(Exception) as e: @@ -351,7 +351,7 @@ def test_credential_phase(client: Client) -> None: # Delete channel from the device by sending badly encrypted message # This is done to prevent channel replacement and trigerring of autoconnect false -> true - protocol.noise.noise_protocol.cipher_state_encrypt.n = 100 + protocol._noise.noise_protocol.cipher_state_encrypt.n = 100 protocol._send_message(ButtonAck()) with pytest.raises(Exception) as e: