mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-27 06:42:02 +00:00
feat: implement thp channel replacement
[no changelog]
This commit is contained in:
parent
a3f2929938
commit
1cf789d82d
@ -393,9 +393,6 @@ async def _handle_credential_request(
|
|||||||
autoconnect=autoconnect,
|
autoconnect=autoconnect,
|
||||||
)
|
)
|
||||||
credential = issue_credential(message.host_static_pubkey, credential_metadata)
|
credential = issue_credential(message.host_static_pubkey, credential_metadata)
|
||||||
ctx.channel_ctx.channel_cache.set_host_static_pubkey(
|
|
||||||
bytearray(message.host_static_pubkey)
|
|
||||||
) # TODO This could raise an exception, should be handled?
|
|
||||||
|
|
||||||
return await ctx.call_any(
|
return await ctx.call_any(
|
||||||
ThpCredentialResponse(
|
ThpCredentialResponse(
|
||||||
@ -416,6 +413,7 @@ async def _handle_end_request(
|
|||||||
|
|
||||||
|
|
||||||
async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
|
async def _end_pairing(ctx: PairingContext) -> ThpEndResponse:
|
||||||
|
ctx.channel_ctx.replace_old_channels_with_the_same_host_pubkey()
|
||||||
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||||
return ThpEndResponse()
|
return ThpEndResponse()
|
||||||
|
|
||||||
|
@ -232,6 +232,41 @@ def create_or_replace_session(
|
|||||||
return _SESSIONS[index]
|
return _SESSIONS[index]
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_sessions(old_channel: ChannelCache, new_channel: ChannelCache) -> None:
|
||||||
|
for session in _SESSIONS:
|
||||||
|
if session.channel_id == old_channel.channel_id:
|
||||||
|
session.channel_id[:] = new_channel.channel_id
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_channel(old_channel: ChannelCache, new_channel: ChannelCache) -> None:
|
||||||
|
_migrate_sessions(old_channel, new_channel)
|
||||||
|
old_channel.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def conditionally_replace_channel(
|
||||||
|
new_channel: ChannelCache, required_state: int, required_key: int
|
||||||
|
) -> bool:
|
||||||
|
"""Replaces "old channel" cache entry with a `new_channel` if two conditions are met:
|
||||||
|
|
||||||
|
1. The "old channel" is in a state `required_state`
|
||||||
|
2. The "old channel" has the same value for `required_key` as the `new_channel`
|
||||||
|
|
||||||
|
|
||||||
|
Returns: bool - whether any channel was replaced.
|
||||||
|
"""
|
||||||
|
was_any_channel_replaced: bool = False
|
||||||
|
state = required_state.to_bytes(_CHANNEL_STATE_LENGTH, "big")
|
||||||
|
for channel in _CHANNELS:
|
||||||
|
if channel.channel_id == new_channel.channel_id:
|
||||||
|
continue
|
||||||
|
if channel.state == state and channel.get(required_key) == new_channel.get(
|
||||||
|
required_key
|
||||||
|
):
|
||||||
|
_replace_channel(channel, new_channel)
|
||||||
|
was_any_channel_replaced = True
|
||||||
|
return was_any_channel_replaced
|
||||||
|
|
||||||
|
|
||||||
def _get_usage_counter_and_increment() -> int:
|
def _get_usage_counter_and_increment() -> int:
|
||||||
global _usage_counter
|
global _usage_counter
|
||||||
_usage_counter += 1
|
_usage_counter += 1
|
||||||
@ -343,7 +378,7 @@ def clear_all_except_one_session_keys(excluded: Tuple[bytes, bytes]) -> None:
|
|||||||
channel.clear()
|
channel.clear()
|
||||||
|
|
||||||
for session in _SESSIONS:
|
for session in _SESSIONS:
|
||||||
if session.channel_id != cid and session.session_id != sid:
|
if session.channel_id != cid or session.session_id != sid:
|
||||||
session.clear()
|
session.clear()
|
||||||
else:
|
else:
|
||||||
s_last_usage = session.last_usage
|
s_last_usage = session.last_usage
|
||||||
|
@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from storage.cache_common import (
|
from storage.cache_common import (
|
||||||
CHANNEL_HANDSHAKE_HASH,
|
CHANNEL_HANDSHAKE_HASH,
|
||||||
|
CHANNEL_HOST_STATIC_PUBKEY,
|
||||||
CHANNEL_KEY_RECEIVE,
|
CHANNEL_KEY_RECEIVE,
|
||||||
CHANNEL_KEY_SEND,
|
CHANNEL_KEY_SEND,
|
||||||
CHANNEL_NONCE_RECEIVE,
|
CHANNEL_NONCE_RECEIVE,
|
||||||
@ -13,6 +14,7 @@ from storage.cache_thp import (
|
|||||||
TAG_LENGTH,
|
TAG_LENGTH,
|
||||||
ChannelCache,
|
ChannelCache,
|
||||||
clear_sessions_with_channel_id,
|
clear_sessions_with_channel_id,
|
||||||
|
conditionally_replace_channel,
|
||||||
)
|
)
|
||||||
from trezor import log, loop, protobuf, utils, workflow
|
from trezor import log, loop, protobuf, utils, workflow
|
||||||
from trezor.wire.errors import WireBufferError
|
from trezor.wire.errors import WireBufferError
|
||||||
@ -111,6 +113,14 @@ class Channel:
|
|||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("set_channel_state: ", state_to_str(state))
|
self._log("set_channel_state: ", state_to_str(state))
|
||||||
|
|
||||||
|
def replace_old_channels_with_the_same_host_pubkey(self) -> None:
|
||||||
|
was_any_replaced = conditionally_replace_channel(
|
||||||
|
new_channel=self.channel_cache,
|
||||||
|
required_state=ChannelState.ENCRYPTED_TRANSPORT,
|
||||||
|
required_key=CHANNEL_HOST_STATIC_PUBKEY,
|
||||||
|
)
|
||||||
|
log.debug(__name__, "Was any channel replaced? %s", str(was_any_replaced))
|
||||||
|
|
||||||
# READ and DECRYPT
|
# READ and DECRYPT
|
||||||
|
|
||||||
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||||
|
@ -321,6 +321,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
|||||||
|
|
||||||
# key is decoded in handshake._handle_th2_crypto
|
# key is decoded in handshake._handle_th2_crypto
|
||||||
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
|
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
|
||||||
|
ctx.channel_cache.set_host_static_pubkey(bytearray(host_static_pubkey))
|
||||||
|
|
||||||
paired: bool = False
|
paired: bool = False
|
||||||
trezor_state = _TREZOR_STATE_UNPAIRED
|
trezor_state = _TREZOR_STATE_UNPAIRED
|
||||||
@ -335,7 +336,6 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
|||||||
if paired:
|
if paired:
|
||||||
trezor_state = _TREZOR_STATE_PAIRED
|
trezor_state = _TREZOR_STATE_PAIRED
|
||||||
ctx.credential = credential
|
ctx.credential = credential
|
||||||
ctx.channel_cache.set_host_static_pubkey(bytearray(host_static_pubkey))
|
|
||||||
else:
|
else:
|
||||||
ctx.credential = None
|
ctx.credential = None
|
||||||
except DataError as e:
|
except DataError as e:
|
||||||
|
@ -282,17 +282,9 @@ class ProtocolV2Channel(Channel):
|
|||||||
h = _sha256_of_two(h, tag_of_empty_string)
|
h = _sha256_of_two(h, tag_of_empty_string)
|
||||||
|
|
||||||
# TODO: search for saved credentials
|
# TODO: search for saved credentials
|
||||||
if host_static_privkey is not None and credential is not None:
|
if host_static_privkey is None:
|
||||||
|
host_static_privkey = curve25519.get_private_key(os.urandom(32))
|
||||||
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
|
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
|
||||||
else:
|
|
||||||
credential = None
|
|
||||||
zeroes_32 = int.to_bytes(0, 32, "little")
|
|
||||||
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
|
||||||
temp_host_static_pubkey = curve25519.get_public_key(
|
|
||||||
temp_host_static_privkey
|
|
||||||
)
|
|
||||||
host_static_privkey = temp_host_static_privkey
|
|
||||||
host_static_pubkey = temp_host_static_pubkey
|
|
||||||
|
|
||||||
aes_ctx = AESGCM(k)
|
aes_ctx = AESGCM(k)
|
||||||
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
|
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
|
||||||
|
@ -6,7 +6,8 @@ from hashlib import sha256
|
|||||||
import pytest
|
import pytest
|
||||||
import typing_extensions as tx
|
import typing_extensions as tx
|
||||||
|
|
||||||
from trezorlib import protobuf
|
from tests.common import get_test_address
|
||||||
|
from trezorlib import exceptions, protobuf
|
||||||
from trezorlib.client import ProtocolV2Channel
|
from trezorlib.client import ProtocolV2Channel
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
from trezorlib.messages import (
|
from trezorlib.messages import (
|
||||||
@ -47,18 +48,32 @@ def _prepare_protocol(client: Client) -> ProtocolV2Channel:
|
|||||||
protocol = client.protocol
|
protocol = client.protocol
|
||||||
assert isinstance(protocol, ProtocolV2Channel)
|
assert isinstance(protocol, ProtocolV2Channel)
|
||||||
protocol._reset_sync_bits()
|
protocol._reset_sync_bits()
|
||||||
return protocol
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2Channel:
|
|
||||||
protocol = _prepare_protocol(client)
|
|
||||||
protocol._do_channel_allocation()
|
protocol._do_channel_allocation()
|
||||||
protocol._do_handshake()
|
|
||||||
return protocol
|
return protocol
|
||||||
|
|
||||||
|
|
||||||
def _handle_pairing_request(client: Client, protocol: ProtocolV2Channel) -> None:
|
def _prepare_protocol_for_pairing(
|
||||||
protocol._send_message(ThpPairingRequest())
|
client: Client, host_static_privkey: bytes | None = None
|
||||||
|
) -> ProtocolV2Channel:
|
||||||
|
protocol = _prepare_protocol(client)
|
||||||
|
protocol._do_handshake(host_static_privkey=host_static_privkey)
|
||||||
|
return protocol
|
||||||
|
|
||||||
|
|
||||||
|
def _get_encrypted_transport_protocol(
|
||||||
|
client: Client, host_static_privkey: bytes | None = None
|
||||||
|
) -> ProtocolV2Channel:
|
||||||
|
protocol = _prepare_protocol_for_pairing(
|
||||||
|
client, host_static_privkey=host_static_privkey
|
||||||
|
)
|
||||||
|
protocol._do_pairing(client.debug)
|
||||||
|
return protocol
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_pairing_request(
|
||||||
|
client: Client, protocol: ProtocolV2Channel, host_name: str | None = None
|
||||||
|
) -> None:
|
||||||
|
protocol._send_message(ThpPairingRequest(host_name=host_name))
|
||||||
button_req = protocol._read_message(ButtonRequest)
|
button_req = protocol._read_message(ButtonRequest)
|
||||||
assert button_req.name == "pairing_request"
|
assert button_req.name == "pairing_request"
|
||||||
|
|
||||||
@ -126,7 +141,7 @@ def test_handshake(client: Client) -> None:
|
|||||||
|
|
||||||
def test_pairing_qr_code(client: Client) -> None:
|
def test_pairing_qr_code(client: Client) -> None:
|
||||||
protocol = _prepare_protocol_for_pairing(client)
|
protocol = _prepare_protocol_for_pairing(client)
|
||||||
_handle_pairing_request(client, protocol)
|
_handle_pairing_request(client, protocol, "TestTrezor QrCode")
|
||||||
protocol._send_message(
|
protocol._send_message(
|
||||||
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)
|
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)
|
||||||
)
|
)
|
||||||
@ -166,7 +181,7 @@ def test_pairing_qr_code(client: Client) -> None:
|
|||||||
def test_pairing_code_entry(client: Client) -> None:
|
def test_pairing_code_entry(client: Client) -> None:
|
||||||
protocol = _prepare_protocol_for_pairing(client)
|
protocol = _prepare_protocol_for_pairing(client)
|
||||||
|
|
||||||
_handle_pairing_request(client, protocol)
|
_handle_pairing_request(client, protocol, "TestTrezor CodeEntry")
|
||||||
|
|
||||||
protocol._send_message(
|
protocol._send_message(
|
||||||
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)
|
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)
|
||||||
@ -232,9 +247,9 @@ def test_pairing_nfc(client: Client) -> None:
|
|||||||
protocol._has_valid_channel = True
|
protocol._has_valid_channel = True
|
||||||
|
|
||||||
|
|
||||||
def _nfc_pairing(client: Client, protocol: ProtocolV2Channel):
|
def _nfc_pairing(client: Client, protocol: ProtocolV2Channel) -> None:
|
||||||
|
|
||||||
_handle_pairing_request(client, protocol)
|
_handle_pairing_request(client, protocol, "TestTrezor NfcPairing")
|
||||||
|
|
||||||
protocol._send_message(
|
protocol._send_message(
|
||||||
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC)
|
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC)
|
||||||
@ -273,7 +288,7 @@ def _nfc_pairing(client: Client, protocol: ProtocolV2Channel):
|
|||||||
assert tag_trezor_msg.tag == computed_tag
|
assert tag_trezor_msg.tag == computed_tag
|
||||||
|
|
||||||
|
|
||||||
def test_credential_phase(client: Client):
|
def test_credential_phase(client: Client) -> None:
|
||||||
protocol = _prepare_protocol_for_pairing(client)
|
protocol = _prepare_protocol_for_pairing(client)
|
||||||
_nfc_pairing(client, protocol)
|
_nfc_pairing(client, protocol)
|
||||||
|
|
||||||
@ -324,3 +339,48 @@ def test_credential_phase(client: Client):
|
|||||||
protocol._do_handshake(credential_auto, host_static_privkey)
|
protocol._do_handshake(credential_auto, host_static_privkey)
|
||||||
protocol._send_message(ThpEndRequest())
|
protocol._send_message(ThpEndRequest())
|
||||||
protocol._read_message(ThpEndResponse)
|
protocol._read_message(ThpEndResponse)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.setup_client(passphrase=True)
|
||||||
|
def test_channel_replacement(client: Client) -> None:
|
||||||
|
assert client.features.passphrase_protection is True
|
||||||
|
|
||||||
|
host_static_privkey = curve25519.get_private_key(os.urandom(32))
|
||||||
|
host_static_privkey_2 = curve25519.get_private_key(os.urandom(32))
|
||||||
|
|
||||||
|
assert host_static_privkey != host_static_privkey_2
|
||||||
|
|
||||||
|
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey)
|
||||||
|
|
||||||
|
session = client.get_session(passphrase="TREZOR", session_id=20)
|
||||||
|
address = get_test_address(session)
|
||||||
|
|
||||||
|
session_2 = client.get_session(passphrase="ROZERT", session_id=30)
|
||||||
|
address_2 = get_test_address(session_2)
|
||||||
|
assert address != address_2
|
||||||
|
|
||||||
|
# create new channel using the same host_static_privkey
|
||||||
|
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey)
|
||||||
|
session_3 = client.get_session(passphrase="OKIDOKI", session_id=40)
|
||||||
|
address_3 = get_test_address(session_3)
|
||||||
|
assert address_3 != address_2
|
||||||
|
|
||||||
|
# test address on regenerated channel
|
||||||
|
new_address = get_test_address(session)
|
||||||
|
assert address == new_address
|
||||||
|
new_address_3 = get_test_address(session_3)
|
||||||
|
assert address_3 == new_address_3
|
||||||
|
|
||||||
|
# create new channel using different host_static_privkey
|
||||||
|
client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey_2)
|
||||||
|
with pytest.raises(exceptions.TrezorFailure) as e_1:
|
||||||
|
_ = get_test_address(session)
|
||||||
|
assert str(e_1.value.message) == "Invalid session"
|
||||||
|
|
||||||
|
with pytest.raises(exceptions.TrezorFailure) as e_2:
|
||||||
|
_ = get_test_address(session_3)
|
||||||
|
assert str(e_2.value.message) == "Invalid session"
|
||||||
|
|
||||||
|
session_4 = client.get_session(passphrase="TREZOR", session_id=80)
|
||||||
|
super_new_address = get_test_address(session_4)
|
||||||
|
assert address == super_new_address
|
||||||
|
Loading…
Reference in New Issue
Block a user