1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-09 22:22:38 +00:00

feat(core, python): implement autoconnect credentials and add test

[no changelog]
This commit is contained in:
M1nd3r 2025-01-31 12:41:45 +01:00
parent d50ec53136
commit 7c154f441e
8 changed files with 224 additions and 55 deletions

View File

@ -63,17 +63,26 @@ def issue_credential(
return credential_raw
def validate_credential(
def decode_credential(
encoded_pairing_credential_message: bytes,
) -> ThpPairingCredential:
"""
Decode a protobuf encoded pairing credential.
"""
expected_type = protobuf.type_for_name("ThpPairingCredential")
credential = wrap_protobuf_load(encoded_pairing_credential_message, expected_type)
assert ThpPairingCredential.is_type_of(credential)
return credential
def validate_credential(
credential: ThpPairingCredential,
host_static_pubkey: bytes,
) -> bool:
"""
Validate a pairing credential binded to the provided host static public key.
"""
cred_auth_key = derive_cred_auth_key()
expected_type = protobuf.type_for_name("ThpPairingCredential")
credential = wrap_protobuf_load(encoded_pairing_credential_message, expected_type)
assert ThpPairingCredential.is_type_of(credential)
proto_msg = ThpAuthenticatedCredentialData(
host_static_pubkey=host_static_pubkey,
cred_metadata=credential.cred_metadata,
@ -83,6 +92,27 @@ def validate_credential(
return mac == credential.mac
def decode_and_validate_credential(
encoded_pairing_credential_message: bytes,
host_static_pubkey: bytes,
) -> bool:
"""
Decode a protobuf encoded pairing credential and validate it
binded to the provided host static public key.
"""
credential = decode_credential(encoded_pairing_credential_message)
return validate_credential(credential, host_static_pubkey)
def is_credential_autoconnect(credential: ThpPairingCredential) -> bool:
assert ThpPairingCredential.is_type_of(credential)
if credential.cred_metadata is None:
return False
if credential.cred_metadata.autoconnect is None:
return False
return credential.cred_metadata.autoconnect
def _encode_message_into_new_buffer(msg: protobuf.MessageType) -> bytes:
msg_len = protobuf.encoded_length(msg)
new_buffer = bytearray(msg_len)

View File

@ -31,7 +31,7 @@ from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage
from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods
from trezor.wire.thp.pairing_context import PairingContext
from .credential_manager import issue_credential
from .credential_manager import is_credential_autoconnect, issue_credential
if __debug__:
from trezor import log
@ -116,7 +116,8 @@ async def handle_pairing_request(
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
try:
response = await ctx.show_pairing_method_screen()
# Should raise UnexpectedMessageException
await ctx.show_pairing_method_screen()
except UnexpectedMessageException as e:
raw_response = e.msg
name = message_handler.get_msg_name(raw_response.type)
@ -137,12 +138,39 @@ async def handle_pairing_request(
else:
break
response = await _handle_different_pairing_methods(ctx, response)
response: protobuf.MessageType = await _handle_different_pairing_methods(
ctx, response
)
return await handle_credential_phase(
ctx,
message=response,
show_connection_dialog=False,
)
while ThpCredentialRequest.is_type_of(response):
response = await _handle_credential_request(ctx, response)
return await _handle_end_request(ctx, response)
@check_state_and_log(ChannelState.TC1)
async def handle_credential_phase(
ctx: PairingContext,
message: protobuf.MessageType,
show_connection_dialog: bool = True,
) -> ThpEndResponse:
autoconnect: bool = False
credential = ctx.channel_ctx.credential
if credential is not None:
autoconnect = is_credential_autoconnect(credential)
if credential.cred_metadata is not None:
ctx.host_name = credential.cred_metadata.host_name
if ctx.host_name is None:
raise Exception("Credential does not have a hostname")
if show_connection_dialog and not autoconnect:
await ctx.show_connection_dialogue()
while ThpCredentialRequest.is_type_of(message):
message = await _handle_credential_request(ctx, message)
return await _handle_end_request(ctx, message)
async def _prepare_pairing(ctx: PairingContext) -> None:
@ -375,8 +403,15 @@ async def _handle_credential_request(
if message.host_static_pubkey is None:
raise Exception("Invalid message") # TODO change failure type
autoconnect: bool = False
if message.autoconnect is not None:
autoconnect = message.autoconnect
trezor_static_pubkey = crypto.get_trezor_static_pubkey()
credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name)
credential_metadata = ThpCredentialMetadata(
host_name=ctx.host_name,
autoconnect=autoconnect,
)
credential = issue_credential(message.host_static_pubkey, credential_metadata)
return await ctx.call_any(

View File

@ -45,6 +45,8 @@ if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Awaitable
from trezor.messages import ThpPairingCredential
from .pairing_context import PairingContext
from .session_context import GenericSessionContext
@ -77,6 +79,7 @@ class Channel:
# Temporary objects
self.handshake: crypto.Handshake | None = None
self.credential: ThpPairingCredential | None = None
self.connection_context: PairingContext | None = None
self.busy_decoder: crypto.BusyDecoder | None = None
self.temp_crc: int | None = None

View File

@ -8,7 +8,7 @@ from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.errors import ActionCancelled, SilentError
from trezor.wire.protocol_common import Context, Message
from trezor.wire.thp import get_enabled_pairing_methods
from trezor.wire.thp import ChannelState, get_enabled_pairing_methods
if TYPE_CHECKING:
from typing import Awaitable, Container
@ -92,7 +92,7 @@ class PairingContext(Context):
self.display_data: PairingDisplayData = PairingDisplayData()
self.cpace: Cpace
self.host_name: str
self.host_name: str | None
async def handle(self) -> None:
next_message: Message | None = None
@ -115,7 +115,7 @@ class PairingContext(Context):
next_message = None
try:
next_message = await handle_pairing_request_message(self, message)
next_message = await handle_message(self, message)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
@ -219,6 +219,19 @@ class PairingContext(Context):
if result == trezorui_api.CONFIRMED:
await self.write(ThpPairingRequestApproved())
async def show_connection_dialogue(self) -> None:
from trezor.ui.layouts.common import interact
await interact(
trezorui_api.confirm_action(
title="Connection dialogue",
action="Do you want previously connected device to connect?",
description="Choose wisely! (or not)",
),
br_name="connection_request",
br_code=ButtonRequestType.Other,
)
async def show_pairing_method_screen(
self, selected_method: ThpPairingMethod | None = None
) -> UiResult:
@ -264,14 +277,14 @@ class PairingContext(Context):
return result
async def handle_pairing_request_message(
async def handle_message(
pairing_ctx: PairingContext,
msg: protocol_common.Message,
) -> protocol_common.Message | None:
res_msg: protobuf.MessageType | None = None
from apps.thp.pairing import handle_pairing_request
from apps.thp.pairing import handle_pairing_request, handle_credential_phase
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
@ -292,7 +305,10 @@ async def handle_pairing_request_message(
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handle_pairing_request(pairing_ctx, req_msg)
if pairing_ctx.channel_ctx.get_channel_state() == ChannelState.TC1:
task = handle_credential_phase(pairing_ctx, req_msg)
else:
task = handle_pairing_request(pairing_ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a

View File

@ -61,6 +61,7 @@ if __debug__:
_TREZOR_STATE_UNPAIRED = b"\x00"
_TREZOR_STATE_PAIRED = b"\x01"
_TREZOR_STATE_PAIRED_AUTOCONNECT = b"\x02"
async def handle_received_message(
@ -264,7 +265,7 @@ async def _handle_state_TH1(
async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None:
from apps.thp.credential_manager import validate_credential
from apps.thp.credential_manager import decode_credential, validate_credential
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "handle_state_TH2")
@ -322,21 +323,25 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH]
paired: bool = False
trezor_state = _TREZOR_STATE_UNPAIRED
if noise_payload.host_pairing_credential is not None:
try: # TODO change try-except for something better
credential = decode_credential(noise_payload.host_pairing_credential)
paired = validate_credential(
noise_payload.host_pairing_credential,
credential,
host_static_pubkey,
)
if paired:
trezor_state = _TREZOR_STATE_PAIRED
ctx.credential = credential
else:
ctx.credential = None
except DataError as e:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.exception(__name__, e)
pass
trezor_state = _TREZOR_STATE_UNPAIRED
if paired:
trezor_state = _TREZOR_STATE_PAIRED
# send hanshake completion response
ctx.write_handshake_message(
HANDSHAKE_COMP_RES,

View File

@ -48,16 +48,28 @@ class TestTrezorHostProtocolCredentialManager(unittest.TestCase):
cred_3 = _issue_credential(HOST_NAME_2, DUMMY_KEY_1)
self.assertNotEqual(cred_1, cred_3)
self.assertTrue(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
self.assertTrue(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_2))
self.assertTrue(
credential_manager.decode_and_validate_credential(cred_1, DUMMY_KEY_1)
)
self.assertTrue(
credential_manager.decode_and_validate_credential(cred_3, DUMMY_KEY_1)
)
self.assertFalse(
credential_manager.decode_and_validate_credential(cred_1, DUMMY_KEY_2)
)
credential_manager.invalidate_cred_auth_key()
cred_4 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1)
self.assertNotEqual(cred_1, cred_4)
self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_1))
self.assertFalse(credential_manager.validate_credential(cred_3, DUMMY_KEY_1))
self.assertTrue(credential_manager.validate_credential(cred_4, DUMMY_KEY_1))
self.assertFalse(
credential_manager.decode_and_validate_credential(cred_1, DUMMY_KEY_1)
)
self.assertFalse(
credential_manager.decode_and_validate_credential(cred_3, DUMMY_KEY_1)
)
self.assertTrue(
credential_manager.decode_and_validate_credential(cred_4, DUMMY_KEY_1)
)
def test_protobuf_encoding(self):
"""

View File

@ -188,7 +188,9 @@ class ProtocolV2(ProtocolAndChannel):
device_properties = payload[10:]
return (channel_id, device_properties)
def _do_handshake(self):
def _do_handshake(
self, credential: bytes | None = None, host_static_privkey: bytes | None = None
):
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
@ -208,6 +210,8 @@ class ProtocolV2(ProtocolAndChannel):
host_ephemeral_privkey,
trezor_ephemeral_pubkey,
encrypted_trezor_static_pubkey,
credential,
host_static_privkey,
)
self._read_ack()
self._read_handshake_completion_response()
@ -246,6 +250,7 @@ class ProtocolV2(ProtocolAndChannel):
trezor_ephemeral_pubkey: bytes,
encrypted_trezor_static_pubkey: bytes,
credential: bytes | None = None,
host_static_privkey: bytes | None = None,
) -> bytes:
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
@ -275,16 +280,25 @@ class ProtocolV2(ProtocolAndChannel):
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials (or possibly not, as we skip pairing phase)
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey)
# TODO: search for saved credentials
if host_static_privkey is not None and credential is not None:
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
else:
credential = None
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(
temp_host_static_privkey
)
host_static_privkey = temp_host_static_privkey
host_static_pubkey = temp_host_static_pubkey
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey)
ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = self.mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(

View File

@ -17,6 +17,8 @@ from trezorlib.messages import (
ThpCodeEntryCpaceHostTag,
ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpCredentialRequest,
ThpCredentialResponse,
ThpEndRequest,
ThpEndResponse,
ThpNfcTagHost,
@ -55,6 +57,18 @@ def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2:
return protocol
def _handle_pairing_request(client: Client, protocol: ProtocolV2) -> None:
protocol._send_message(ThpPairingRequest())
button_req = protocol._read_message(ButtonRequest)
assert button_req.name == "pairing_request"
protocol._send_message(ButtonAck())
client.debug.press_yes()
protocol._read_message(ThpPairingRequestApproved)
def test_allocate_channel(client: Client) -> None:
protocol = _prepare_protocol(client)
@ -112,14 +126,7 @@ def test_handshake(client: Client) -> None:
def test_pairing_qr_code(client: Client) -> None:
protocol = _prepare_protocol_for_pairing(client)
protocol._send_message(ThpPairingRequest())
protocol._read_message(ButtonRequest)
protocol._send_message(ButtonAck())
client.debug.press_yes()
protocol._read_message(ThpPairingRequestApproved)
_handle_pairing_request(client, protocol)
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode)
)
@ -161,13 +168,8 @@ def test_pairing_qr_code(client: Client) -> None:
def test_pairing_code_entry(client: Client) -> None:
protocol = _prepare_protocol_for_pairing(client)
protocol._send_message(ThpPairingRequest())
protocol._read_message(ButtonRequest)
protocol._send_message(ButtonAck())
_handle_pairing_request(client, protocol)
client.debug.press_yes()
protocol._read_message(ThpPairingRequestApproved)
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry)
)
@ -227,13 +229,17 @@ def test_pairing_code_entry(client: Client) -> None:
def test_pairing_nfc(client: Client) -> None:
protocol = _prepare_protocol_for_pairing(client)
protocol._send_message(ThpPairingRequest())
protocol._read_message(ButtonRequest)
protocol._send_message(ButtonAck())
_nfc_pairing(client, protocol)
client.debug.press_yes()
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
protocol._has_valid_channel = True
def _nfc_pairing(client: Client, protocol: ProtocolV2):
_handle_pairing_request(client, protocol)
protocol._read_message(ThpPairingRequestApproved)
protocol._send_message(
ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC)
)
@ -272,7 +278,55 @@ def test_pairing_nfc(client: Client) -> None:
computed_tag = sha_ctx.digest()
assert tag_trezor_msg.tag == computed_tag
def test_credential_phase(client: Client):
protocol = _prepare_protocol_for_pairing(client)
_nfc_pairing(client, protocol)
# Request credential with confirmation after pairing
host_static_privkey = curve25519.get_private_key(os.urandom(32))
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
protocol._send_message(
ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False)
)
credential_response = protocol._read_message(ThpCredentialResponse)
assert credential_response.credential is not None
credential = credential_response.credential
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
protocol._has_valid_channel = True
# Connect using credential with confirmation
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake(credential, host_static_privkey)
protocol._send_message(ThpEndRequest())
button_req = protocol._read_message(ButtonRequest)
assert button_req.name == "connection_request"
protocol._send_message(ButtonAck())
client.debug.press_yes()
protocol._read_message(ThpEndResponse)
# Connect using credential with confirmation and ask for autoconnect credential
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake(credential, host_static_privkey)
protocol._send_message(
ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=True)
)
button_req = protocol._read_message(ButtonRequest)
assert button_req.name == "connection_request"
protocol._send_message(ButtonAck())
client.debug.press_yes()
credential_response_2 = protocol._read_message(ThpCredentialResponse)
assert credential_response_2.credential is not None
credential_auto = credential_response_2.credential
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)
# Connect using autoconnect credential
protocol = _prepare_protocol(client)
protocol._do_channel_allocation()
protocol._do_handshake(credential_auto, host_static_privkey)
protocol._send_message(ThpEndRequest())
protocol._read_message(ThpEndResponse)