From 98e75f2e5140877037f98b653f3b4a39fa5cee31 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Mon, 24 Feb 2025 17:21:25 +0100 Subject: [PATCH] feat: implement thp channel replacement [no changelog] --- core/src/apps/thp/pairing.py | 4 +- core/src/storage/cache_thp.py | 37 +++++++- core/src/trezor/wire/thp/channel.py | 10 +++ .../wire/thp/received_message_handler.py | 2 +- .../trezorlib/transport/thp/protocol_v2.py | 14 +-- tests/device_tests/thp/test_thp.py | 88 ++++++++++++++++--- 6 files changed, 125 insertions(+), 30 deletions(-) diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index 8171e34481..aa08041f4b 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -393,9 +393,6 @@ async def _handle_credential_request( autoconnect=autoconnect, ) credential = issue_credential(message.host_static_pubkey, credential_metadata) - ctx.channel_ctx.channel_cache.set_host_static_pubkey( - bytearray(message.host_static_pubkey) - ) # TODO This could raise an exception, should be handled? return await ctx.call_any( ThpCredentialResponse( @@ -416,6 +413,7 @@ async def _handle_end_request( async def _end_pairing(ctx: PairingContext) -> ThpEndResponse: + ctx.channel_ctx.replace_old_channels_with_the_same_host_pubkey() ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) return ThpEndResponse() diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 235983be02..6f5019bc80 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -232,6 +232,41 @@ def create_or_replace_session( return _SESSIONS[index] +def _migrate_sessions(old_channel: ChannelCache, new_channel: ChannelCache) -> None: + for session in _SESSIONS: + if session.channel_id == old_channel.channel_id: + session.channel_id[:] = new_channel.channel_id + + +def _replace_channel(old_channel: ChannelCache, new_channel: ChannelCache) -> None: + _migrate_sessions(old_channel, new_channel) + old_channel.clear() + + +def conditionally_replace_channel( + new_channel: ChannelCache, required_state: int, required_key: int +) -> bool: + """Replaces "old channel" cache entry with a `new_channel` if two conditions are met: + + 1. The "old channel" is in a state `required_state` + 2. The "old channel" has the same value for `required_key` as the `new_channel` + + + Returns: bool - whether any channel was replaced. + """ + was_any_channel_replaced: bool = False + state = required_state.to_bytes(_CHANNEL_STATE_LENGTH, "big") + for channel in _CHANNELS: + if channel.channel_id == new_channel.channel_id: + continue + if channel.state == state and channel.get(required_key) == new_channel.get( + required_key + ): + _replace_channel(channel, new_channel) + was_any_channel_replaced = True + return was_any_channel_replaced + + def _get_usage_counter_and_increment() -> int: global _usage_counter _usage_counter += 1 @@ -343,7 +378,7 @@ def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None: channel.clear() for session in _SESSIONS: - if session.channel_id != cid and session.session_id != sid: + if session.channel_id != cid or session.session_id != sid: session.clear() else: s_last_usage = session.last_usage diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 39baa27e49..953ddf3879 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from storage.cache_common import ( CHANNEL_HANDSHAKE_HASH, + CHANNEL_HOST_STATIC_PUBKEY, CHANNEL_KEY_RECEIVE, CHANNEL_KEY_SEND, CHANNEL_NONCE_RECEIVE, @@ -13,6 +14,7 @@ from storage.cache_thp import ( TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id, + conditionally_replace_channel, ) from trezor import log, loop, protobuf, utils, workflow from trezor.wire.errors import WireBufferError @@ -111,6 +113,14 @@ class Channel: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("set_channel_state: ", state_to_str(state)) + def replace_old_channels_with_the_same_host_pubkey(self) -> None: + was_any_replaced = conditionally_replace_channel( + new_channel=self.channel_cache, + required_state=ChannelState.ENCRYPTED_TRANSPORT, + required_key=CHANNEL_HOST_STATIC_PUBKEY, + ) + log.debug(__name__, "Was any channel replaced? %s", str(was_any_replaced)) + # READ and DECRYPT def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py index 4f6a7afc3b..022db83701 100644 --- a/core/src/trezor/wire/thp/received_message_handler.py +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -321,6 +321,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) - # key is decoded in handshake._handle_th2_crypto host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH] + ctx.channel_cache.set_host_static_pubkey(bytearray(host_static_pubkey)) paired: bool = False trezor_state = _TREZOR_STATE_UNPAIRED @@ -335,7 +336,6 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) - if paired: trezor_state = _TREZOR_STATE_PAIRED ctx.credential = credential - ctx.channel_cache.set_host_static_pubkey(bytearray(host_static_pubkey)) else: ctx.credential = None except DataError as e: diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index dc07772595..6c5021fedc 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -282,17 +282,9 @@ class ProtocolV2Channel(Channel): h = _sha256_of_two(h, tag_of_empty_string) # TODO: search for saved credentials - if host_static_privkey is not None and credential is not None: - host_static_pubkey = curve25519.get_public_key(host_static_privkey) - else: - credential = None - zeroes_32 = int.to_bytes(0, 32, "little") - temp_host_static_privkey = curve25519.get_private_key(zeroes_32) - temp_host_static_pubkey = curve25519.get_public_key( - temp_host_static_privkey - ) - host_static_privkey = temp_host_static_privkey - host_static_pubkey = temp_host_static_pubkey + if host_static_privkey is None: + host_static_privkey = curve25519.get_private_key(os.urandom(32)) + host_static_pubkey = curve25519.get_public_key(host_static_privkey) aes_ctx = AESGCM(k) encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h) diff --git a/tests/device_tests/thp/test_thp.py b/tests/device_tests/thp/test_thp.py index a909bf3621..8519ed21d9 100644 --- a/tests/device_tests/thp/test_thp.py +++ b/tests/device_tests/thp/test_thp.py @@ -6,7 +6,8 @@ from hashlib import sha256 import pytest import typing_extensions as tx -from trezorlib import protobuf +from tests.common import get_test_address +from trezorlib import exceptions, protobuf from trezorlib.client import ProtocolV2Channel from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.messages import ( @@ -47,18 +48,32 @@ def _prepare_protocol(client: Client) -> ProtocolV2Channel: protocol = client.protocol assert isinstance(protocol, ProtocolV2Channel) protocol._reset_sync_bits() - return protocol - - -def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2Channel: - protocol = _prepare_protocol(client) protocol._do_channel_allocation() - protocol._do_handshake() return protocol -def _handle_pairing_request(client: Client, protocol: ProtocolV2Channel) -> None: - protocol._send_message(ThpPairingRequest()) +def _prepare_protocol_for_pairing( + client: Client, host_static_privkey: bytes | None = None +) -> ProtocolV2Channel: + protocol = _prepare_protocol(client) + protocol._do_handshake(host_static_privkey=host_static_privkey) + return protocol + + +def _get_encrypted_transport_protocol( + client: Client, host_static_privkey: bytes | None = None +) -> ProtocolV2Channel: + protocol = _prepare_protocol_for_pairing( + client, host_static_privkey=host_static_privkey + ) + protocol._do_pairing(client.debug) + return protocol + + +def _handle_pairing_request( + client: Client, protocol: ProtocolV2Channel, host_name: str | None = None +) -> None: + protocol._send_message(ThpPairingRequest(host_name=host_name)) button_req = protocol._read_message(ButtonRequest) assert button_req.name == "pairing_request" @@ -126,7 +141,7 @@ def test_handshake(client: Client) -> None: def test_pairing_qr_code(client: Client) -> None: protocol = _prepare_protocol_for_pairing(client) - _handle_pairing_request(client, protocol) + _handle_pairing_request(client, protocol, "TestTrezor QrCode") protocol._send_message( ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode) ) @@ -166,7 +181,7 @@ def test_pairing_qr_code(client: Client) -> None: def test_pairing_code_entry(client: Client) -> None: protocol = _prepare_protocol_for_pairing(client) - _handle_pairing_request(client, protocol) + _handle_pairing_request(client, protocol, "TestTrezor CodeEntry") protocol._send_message( ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry) @@ -232,9 +247,9 @@ def test_pairing_nfc(client: Client) -> None: protocol._has_valid_channel = True -def _nfc_pairing(client: Client, protocol: ProtocolV2Channel): +def _nfc_pairing(client: Client, protocol: ProtocolV2Channel) -> None: - _handle_pairing_request(client, protocol) + _handle_pairing_request(client, protocol, "TestTrezor NfcPairing") protocol._send_message( ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC) @@ -273,7 +288,7 @@ def _nfc_pairing(client: Client, protocol: ProtocolV2Channel): assert tag_trezor_msg.tag == computed_tag -def test_credential_phase(client: Client): +def test_credential_phase(client: Client) -> None: protocol = _prepare_protocol_for_pairing(client) _nfc_pairing(client, protocol) @@ -324,3 +339,48 @@ def test_credential_phase(client: Client): protocol._do_handshake(credential_auto, host_static_privkey) protocol._send_message(ThpEndRequest()) protocol._read_message(ThpEndResponse) + + +@pytest.mark.setup_client(passphrase=True) +def test_channel_replacement(client: Client) -> None: + assert client.features.passphrase_protection is True + + host_static_privkey = curve25519.get_private_key(os.urandom(32)) + host_static_privkey_2 = curve25519.get_private_key(os.urandom(32)) + + assert host_static_privkey != host_static_privkey_2 + + client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey) + + session = client.get_session(passphrase="TREZOR", session_id=20) + address = get_test_address(session) + + session_2 = client.get_session(passphrase="ROZERT", session_id=30) + address_2 = get_test_address(session_2) + assert address != address_2 + + # create new channel using the same host_static_privkey + client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey) + session_3 = client.get_session(passphrase="OKIDOKI", session_id=40) + address_3 = get_test_address(session_3) + assert address_3 != address_2 + + # test address on regenerated channel + new_address = get_test_address(session) + assert address == new_address + new_address_3 = get_test_address(session_3) + assert address_3 == new_address_3 + + # create new channel using different host_static_privkey + client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey_2) + with pytest.raises(exceptions.TrezorFailure) as e_1: + _ = get_test_address(session) + assert str(e_1.value.message) == "Invalid session" + + with pytest.raises(exceptions.TrezorFailure) as e_2: + _ = get_test_address(session_3) + assert str(e_2.value.message) == "Invalid session" + + session_4 = client.get_session(passphrase="TREZOR", session_id=80) + super_new_address = get_test_address(session_4) + assert address == super_new_address