mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-26 06:12:01 +00:00
feat: implement thp channel replacement
[no changelog]
This commit is contained in:
parent
ac71b8f957
commit
98e75f2e51
@ -393,9 +393,6 @@ async def _handle_credential_request(
|
||||
autoconnect=autoconnect,
|
||||
)
|
||||
credential = issue_credential(message.host_static_pubkey, credential_metadata)
|
||||
ctx.channel_ctx.channel_cache.set_host_static_pubkey(
|
||||
bytearray(message.host_static_pubkey)
|
||||
) # TODO This could raise an exception, should be handled?
|
||||
|
||||
return await ctx.call_any(
|
||||
ThpCredentialResponse(
|
||||
@ -416,6 +413,7 @@ async def _handle_end_request(
|
||||
|
||||
|
||||
async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
|
||||
ctx.channel_ctx.replace_old_channels_with_the_same_host_pubkey()
|
||||
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
return ThpEndResponse()
|
||||
|
||||
|
@ -232,6 +232,41 @@ def create_or_replace_session(
|
||||
return _SESSIONS[index]
|
||||
|
||||
|
||||
def _migrate_sessions(old_channel: ChannelCache, new_channel: ChannelCache) -> None:
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id == old_channel.channel_id:
|
||||
session.channel_id[:] = new_channel.channel_id
|
||||
|
||||
|
||||
def _replace_channel(old_channel: ChannelCache, new_channel: ChannelCache) -> None:
|
||||
_migrate_sessions(old_channel, new_channel)
|
||||
old_channel.clear()
|
||||
|
||||
|
||||
def conditionally_replace_channel(
|
||||
new_channel: ChannelCache, required_state: int, required_key: int
|
||||
) -> bool:
|
||||
"""Replaces "old channel" cache entry with a `new_channel` if two conditions are met:
|
||||
|
||||
1. The "old channel" is in a state `required_state`
|
||||
2. The "old channel" has the same value for `required_key` as the `new_channel`
|
||||
|
||||
|
||||
Returns: bool - whether any channel was replaced.
|
||||
"""
|
||||
was_any_channel_replaced: bool = False
|
||||
state = required_state.to_bytes(_CHANNEL_STATE_LENGTH, "big")
|
||||
for channel in _CHANNELS:
|
||||
if channel.channel_id == new_channel.channel_id:
|
||||
continue
|
||||
if channel.state == state and channel.get(required_key) == new_channel.get(
|
||||
required_key
|
||||
):
|
||||
_replace_channel(channel, new_channel)
|
||||
was_any_channel_replaced = True
|
||||
return was_any_channel_replaced
|
||||
|
||||
|
||||
def _get_usage_counter_and_increment() -> int:
|
||||
global _usage_counter
|
||||
_usage_counter += 1
|
||||
@ -343,7 +378,7 @@ def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None:
|
||||
channel.clear()
|
||||
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id != cid and session.session_id != sid:
|
||||
if session.channel_id != cid or session.session_id != sid:
|
||||
session.clear()
|
||||
else:
|
||||
s_last_usage = session.last_usage
|
||||
|
@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import (
|
||||
CHANNEL_HANDSHAKE_HASH,
|
||||
CHANNEL_HOST_STATIC_PUBKEY,
|
||||
CHANNEL_KEY_RECEIVE,
|
||||
CHANNEL_KEY_SEND,
|
||||
CHANNEL_NONCE_RECEIVE,
|
||||
@ -13,6 +14,7 @@ from storage.cache_thp import (
|
||||
TAG_LENGTH,
|
||||
ChannelCache,
|
||||
clear_sessions_with_channel_id,
|
||||
conditionally_replace_channel,
|
||||
)
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.wire.errors import WireBufferError
|
||||
@ -111,6 +113,14 @@ class Channel:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
self._log("set_channel_state: ", state_to_str(state))
|
||||
|
||||
def replace_old_channels_with_the_same_host_pubkey(self) -> None:
|
||||
was_any_replaced = conditionally_replace_channel(
|
||||
new_channel=self.channel_cache,
|
||||
required_state=ChannelState.ENCRYPTED_TRANSPORT,
|
||||
required_key=CHANNEL_HOST_STATIC_PUBKEY,
|
||||
)
|
||||
log.debug(__name__, "Was any channel replaced? %s", str(was_any_replaced))
|
||||
|
||||
# READ and DECRYPT
|
||||
|
||||
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||
|
@ -321,6 +321,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
||||
|
||||
# key is decoded in handshake._handle_th2_crypto
|
||||
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
|
||||
ctx.channel_cache.set_host_static_pubkey(bytearray(host_static_pubkey))
|
||||
|
||||
paired: bool = False
|
||||
trezor_state = _TREZOR_STATE_UNPAIRED
|
||||
@ -335,7 +336,6 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
||||
if paired:
|
||||
trezor_state = _TREZOR_STATE_PAIRED
|
||||
ctx.credential = credential
|
||||
ctx.channel_cache.set_host_static_pubkey(bytearray(host_static_pubkey))
|
||||
else:
|
||||
ctx.credential = None
|
||||
except DataError as e:
|
||||
|
@ -282,17 +282,9 @@ class ProtocolV2Channel(Channel):
|
||||
h = _sha256_of_two(h, tag_of_empty_string)
|
||||
|
||||
# TODO: search for saved credentials
|
||||
if host_static_privkey is not None and credential is not None:
|
||||
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
|
||||
else:
|
||||
credential = None
|
||||
zeroes_32 = int.to_bytes(0, 32, "little")
|
||||
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||
temp_host_static_pubkey = curve25519.get_public_key(
|
||||
temp_host_static_privkey
|
||||
)
|
||||
host_static_privkey = temp_host_static_privkey
|
||||
host_static_pubkey = temp_host_static_pubkey
|
||||
if host_static_privkey is None:
|
||||
host_static_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
|
||||
|
@ -6,7 +6,8 @@ from hashlib import sha256
|
||||
import pytest
|
||||
import typing_extensions as tx
|
||||
|
||||
from trezorlib import protobuf
|
||||
from tests.common import get_test_address
|
||||
from trezorlib import exceptions, protobuf
|
||||
from trezorlib.client import ProtocolV2Channel
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.messages import (
|
||||
@ -47,18 +48,32 @@ def _prepare_protocol(client: Client) -> ProtocolV2Channel:
|
||||
protocol = client.protocol
|
||||
assert isinstance(protocol, ProtocolV2Channel)
|
||||
protocol._reset_sync_bits()
|
||||
return protocol
|
||||
|
||||
|
||||
def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2Channel:
|
||||
protocol = _prepare_protocol(client)
|
||||
protocol._do_channel_allocation()
|
||||
protocol._do_handshake()
|
||||
return protocol
|
||||
|
||||
|
||||
def _handle_pairing_request(client: Client, protocol: ProtocolV2Channel) -> None:
|
||||
protocol._send_message(ThpPairingRequest())
|
||||
def _prepare_protocol_for_pairing(
|
||||
client: Client, host_static_privkey: bytes | None = None
|
||||
) -> ProtocolV2Channel:
|
||||
protocol = _prepare_protocol(client)
|
||||
protocol._do_handshake(host_static_privkey=host_static_privkey)
|
||||
return protocol
|
||||
|
||||
|
||||
def _get_encrypted_transport_protocol(
|
||||
client: Client, host_static_privkey: bytes | None = None
|
||||
) -> ProtocolV2Channel:
|
||||
protocol = _prepare_protocol_for_pairing(
|
||||
client, host_static_privkey=host_static_privkey
|
||||
)
|
||||
protocol._do_pairing(client.debug)
|
||||
return protocol
|
||||
|
||||
|
||||
def _handle_pairing_request(
|
||||
client: Client, protocol: ProtocolV2Channel, host_name: str | None = None
|
||||
) -> None:
|
||||
protocol._send_message(ThpPairingRequest(host_name=host_name))
|
||||
button_req = protocol._read_message(ButtonRequest)
|
||||
assert button_req.name == "pairing_request"
|
||||
|
||||
@ -126,7 +141,7 @@ def test_handshake(client: Client) -> None:
|
||||
|
||||
def test_pairing_qr_code(client: Client) -> None:
|
||||
protocol = _prepare_protocol_for_pairing(client)
|
||||
_handle_pairing_request(client, protocol)
|
||||
_handle_pairing_request(client, protocol, "TestTrezor QrCode")
|
||||
protocol._send_message(
|
||||
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)
|
||||
)
|
||||
@ -166,7 +181,7 @@ def test_pairing_qr_code(client: Client) -> None:
|
||||
def test_pairing_code_entry(client: Client) -> None:
|
||||
protocol = _prepare_protocol_for_pairing(client)
|
||||
|
||||
_handle_pairing_request(client, protocol)
|
||||
_handle_pairing_request(client, protocol, "TestTrezor CodeEntry")
|
||||
|
||||
protocol._send_message(
|
||||
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)
|
||||
@ -232,9 +247,9 @@ def test_pairing_nfc(client: Client) -> None:
|
||||
protocol._has_valid_channel = True
|
||||
|
||||
|
||||
def _nfc_pairing(client: Client, protocol: ProtocolV2Channel):
|
||||
def _nfc_pairing(client: Client, protocol: ProtocolV2Channel) -> None:
|
||||
|
||||
_handle_pairing_request(client, protocol)
|
||||
_handle_pairing_request(client, protocol, "TestTrezor NfcPairing")
|
||||
|
||||
protocol._send_message(
|
||||
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC)
|
||||
@ -273,7 +288,7 @@ def _nfc_pairing(client: Client, protocol: ProtocolV2Channel):
|
||||
assert tag_trezor_msg.tag == computed_tag
|
||||
|
||||
|
||||
def test_credential_phase(client: Client):
|
||||
def test_credential_phase(client: Client) -> None:
|
||||
protocol = _prepare_protocol_for_pairing(client)
|
||||
_nfc_pairing(client, protocol)
|
||||
|
||||
@ -324,3 +339,48 @@ def test_credential_phase(client: Client):
|
||||
protocol._do_handshake(credential_auto, host_static_privkey)
|
||||
protocol._send_message(ThpEndRequest())
|
||||
protocol._read_message(ThpEndResponse)
|
||||
|
||||
|
||||
@pytest.mark.setup_client(passphrase=True)
|
||||
def test_channel_replacement(client: Client) -> None:
|
||||
assert client.features.passphrase_protection is True
|
||||
|
||||
host_static_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
host_static_privkey_2 = curve25519.get_private_key(os.urandom(32))
|
||||
|
||||
assert host_static_privkey != host_static_privkey_2
|
||||
|
||||
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey)
|
||||
|
||||
session = client.get_session(passphrase="TREZOR", session_id=20)
|
||||
address = get_test_address(session)
|
||||
|
||||
session_2 = client.get_session(passphrase="ROZERT", session_id=30)
|
||||
address_2 = get_test_address(session_2)
|
||||
assert address != address_2
|
||||
|
||||
# create new channel using the same host_static_privkey
|
||||
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey)
|
||||
session_3 = client.get_session(passphrase="OKIDOKI", session_id=40)
|
||||
address_3 = get_test_address(session_3)
|
||||
assert address_3 != address_2
|
||||
|
||||
# test address on regenerated channel
|
||||
new_address = get_test_address(session)
|
||||
assert address == new_address
|
||||
new_address_3 = get_test_address(session_3)
|
||||
assert address_3 == new_address_3
|
||||
|
||||
# create new channel using different host_static_privkey
|
||||
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey_2)
|
||||
with pytest.raises(exceptions.TrezorFailure) as e_1:
|
||||
_ = get_test_address(session)
|
||||
assert str(e_1.value.message) == "Invalid session"
|
||||
|
||||
with pytest.raises(exceptions.TrezorFailure) as e_2:
|
||||
_ = get_test_address(session_3)
|
||||
assert str(e_2.value.message) == "Invalid session"
|
||||
|
||||
session_4 = client.get_session(passphrase="TREZOR", session_id=80)
|
||||
super_new_address = get_test_address(session_4)
|
||||
assert address == super_new_address
|
||||
|
Loading…
Reference in New Issue
Block a user