1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-26 14:22:06 +00:00

feat: implement thp channel replacement

[no changelog]
This commit is contained in:
M1nd3r 2025-02-24 17:21:25 +01:00
parent ac71b8f957
commit 98e75f2e51
6 changed files with 125 additions and 30 deletions

View File

@ -393,9 +393,6 @@ async def _handle_credential_request(
autoconnect=autoconnect,
)
credential = issue_credential(message.host_static_pubkey, credential_metadata)
ctx.channel_ctx.channel_cache.set_host_static_pubkey(
bytearray(message.host_static_pubkey)
) # TODO This could raise an exception, should be handled?
return await ctx.call_any(
ThpCredentialResponse(
@ -416,6 +413,7 @@ async def _handle_end_request(
async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
ctx.channel_ctx.replace_old_channels_with_the_same_host_pubkey()
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()

View File

@ -232,6 +232,41 @@ def create_or_replace_session(
return _SESSIONS[index]
def _migrate_sessions(old_channel: ChannelCache, new_channel: ChannelCache) -> None:
for session in _SESSIONS:
if session.channel_id == old_channel.channel_id:
session.channel_id[:] = new_channel.channel_id
def _replace_channel(old_channel: ChannelCache, new_channel: ChannelCache) -> None:
_migrate_sessions(old_channel, new_channel)
old_channel.clear()
def conditionally_replace_channel(
new_channel: ChannelCache, required_state: int, required_key: int
) -> bool:
"""Replaces "old channel" cache entry with a `new_channel` if two conditions are met:
1. The "old channel" is in a state `required_state`
2. The "old channel" has the same value for `required_key` as the `new_channel`
Returns: bool - whether any channel was replaced.
"""
was_any_channel_replaced: bool = False
state = required_state.to_bytes(_CHANNEL_STATE_LENGTH, "big")
for channel in _CHANNELS:
if channel.channel_id == new_channel.channel_id:
continue
if channel.state == state and channel.get(required_key) == new_channel.get(
required_key
):
_replace_channel(channel, new_channel)
was_any_channel_replaced = True
return was_any_channel_replaced
def _get_usage_counter_and_increment() -> int:
global _usage_counter
_usage_counter += 1
@ -343,7 +378,7 @@ def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None:
channel.clear()
for session in _SESSIONS:
if session.channel_id != cid and session.session_id != sid:
if session.channel_id != cid or session.session_id != sid:
session.clear()
else:
s_last_usage = session.last_usage

View File

@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
from storage.cache_common import (
CHANNEL_HANDSHAKE_HASH,
CHANNEL_HOST_STATIC_PUBKEY,
CHANNEL_KEY_RECEIVE,
CHANNEL_KEY_SEND,
CHANNEL_NONCE_RECEIVE,
@ -13,6 +14,7 @@ from storage.cache_thp import (
TAG_LENGTH,
ChannelCache,
clear_sessions_with_channel_id,
conditionally_replace_channel,
)
from trezor import log, loop, protobuf, utils, workflow
from trezor.wire.errors import WireBufferError
@ -111,6 +113,14 @@ class Channel:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("set_channel_state: ", state_to_str(state))
def replace_old_channels_with_the_same_host_pubkey(self) -> None:
was_any_replaced = conditionally_replace_channel(
new_channel=self.channel_cache,
required_state=ChannelState.ENCRYPTED_TRANSPORT,
required_key=CHANNEL_HOST_STATIC_PUBKEY,
)
log.debug(__name__, "Was any channel replaced? %s", str(was_any_replaced))
# READ and DECRYPT
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:

View File

@ -321,6 +321,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
# key is decoded in handshake._handle_th2_crypto
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
ctx.channel_cache.set_host_static_pubkey(bytearray(host_static_pubkey))
paired: bool = False
trezor_state = _TREZOR_STATE_UNPAIRED
@ -335,7 +336,6 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
if paired:
trezor_state = _TREZOR_STATE_PAIRED
ctx.credential = credential
ctx.channel_cache.set_host_static_pubkey(bytearray(host_static_pubkey))
else:
ctx.credential = None
except DataError as e:

View File

@ -282,17 +282,9 @@ class ProtocolV2Channel(Channel):
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials
if host_static_privkey is not None and credential is not None:
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
else:
credential = None
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(
temp_host_static_privkey
)
host_static_privkey = temp_host_static_privkey
host_static_pubkey = temp_host_static_pubkey
if host_static_privkey is None:
host_static_privkey = curve25519.get_private_key(os.urandom(32))
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)

View File

@ -6,7 +6,8 @@ from hashlib import sha256
import pytest
import typing_extensions as tx
from trezorlib import protobuf
from tests.common import get_test_address
from trezorlib import exceptions, protobuf
from trezorlib.client import ProtocolV2Channel
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.messages import (
@ -47,18 +48,32 @@ def _prepare_protocol(client: Client) -> ProtocolV2Channel:
protocol = client.protocol
assert isinstance(protocol, ProtocolV2Channel)
protocol._reset_sync_bits()
return protocol
def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2Channel:
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake()
return protocol
def _handle_pairing_request(client: Client, protocol: ProtocolV2Channel) -> None:
protocol._send_message(ThpPairingRequest())
def _prepare_protocol_for_pairing(
client: Client, host_static_privkey: bytes | None = None
) -> ProtocolV2Channel:
protocol = _prepare_protocol(client)
protocol._do_handshake(host_static_privkey=host_static_privkey)
return protocol
def _get_encrypted_transport_protocol(
client: Client, host_static_privkey: bytes | None = None
) -> ProtocolV2Channel:
protocol = _prepare_protocol_for_pairing(
client, host_static_privkey=host_static_privkey
)
protocol._do_pairing(client.debug)
return protocol
def _handle_pairing_request(
client: Client, protocol: ProtocolV2Channel, host_name: str | None = None
) -> None:
protocol._send_message(ThpPairingRequest(host_name=host_name))
button_req = protocol._read_message(ButtonRequest)
assert button_req.name == "pairing_request"
@ -126,7 +141,7 @@ def test_handshake(client: Client) -> None:
def test_pairing_qr_code(client: Client) -> None:
protocol = _prepare_protocol_for_pairing(client)
_handle_pairing_request(client, protocol)
_handle_pairing_request(client, protocol, "TestTrezor QrCode")
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)
)
@ -166,7 +181,7 @@ def test_pairing_qr_code(client: Client) -> None:
def test_pairing_code_entry(client: Client) -> None:
protocol = _prepare_protocol_for_pairing(client)
_handle_pairing_request(client, protocol)
_handle_pairing_request(client, protocol, "TestTrezor CodeEntry")
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)
@ -232,9 +247,9 @@ def test_pairing_nfc(client: Client) -> None:
protocol._has_valid_channel = True
def _nfc_pairing(client: Client, protocol: ProtocolV2Channel):
def _nfc_pairing(client: Client, protocol: ProtocolV2Channel) -> None:
_handle_pairing_request(client, protocol)
_handle_pairing_request(client, protocol, "TestTrezor NfcPairing")
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC)
@ -273,7 +288,7 @@ def _nfc_pairing(client: Client, protocol: ProtocolV2Channel):
assert tag_trezor_msg.tag == computed_tag
def test_credential_phase(client: Client):
def test_credential_phase(client: Client) -> None:
protocol = _prepare_protocol_for_pairing(client)
_nfc_pairing(client, protocol)
@ -324,3 +339,48 @@ def test_credential_phase(client: Client):
protocol._do_handshake(credential_auto, host_static_privkey)
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
@pytest.mark.setup_client(passphrase=True)
def test_channel_replacement(client: Client) -> None:
assert client.features.passphrase_protection is True
host_static_privkey = curve25519.get_private_key(os.urandom(32))
host_static_privkey_2 = curve25519.get_private_key(os.urandom(32))
assert host_static_privkey != host_static_privkey_2
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey)
session = client.get_session(passphrase="TREZOR", session_id=20)
address = get_test_address(session)
session_2 = client.get_session(passphrase="ROZERT", session_id=30)
address_2 = get_test_address(session_2)
assert address != address_2
# create new channel using the same host_static_privkey
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey)
session_3 = client.get_session(passphrase="OKIDOKI", session_id=40)
address_3 = get_test_address(session_3)
assert address_3 != address_2
# test address on regenerated channel
new_address = get_test_address(session)
assert address == new_address
new_address_3 = get_test_address(session_3)
assert address_3 == new_address_3
# create new channel using different host_static_privkey
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey_2)
with pytest.raises(exceptions.TrezorFailure) as e_1:
_ = get_test_address(session)
assert str(e_1.value.message) == "Invalid session"
with pytest.raises(exceptions.TrezorFailure) as e_2:
_ = get_test_address(session_3)
assert str(e_2.value.message) == "Invalid session"
session_4 = client.get_session(passphrase="TREZOR", session_id=80)
super_new_address = get_test_address(session_4)
assert address == super_new_address