diff --git a/core/src/apps/thp/credential_manager.py b/core/src/apps/thp/credential_manager.py index adf2ba6240..9170a06f9e 100644 --- a/core/src/apps/thp/credential_manager.py +++ b/core/src/apps/thp/credential_manager.py @@ -63,17 +63,26 @@ def issue_credential( return credential_raw -def validate_credential( +def decode_credential( encoded_pairing_credential_message: bytes, +) -> ThpPairingCredential: + """ + Decode a protobuf encoded pairing credential. + """ + expected_type = protobuf.type_for_name("ThpPairingCredential") + credential = wrap_protobuf_load(encoded_pairing_credential_message, expected_type) + assert ThpPairingCredential.is_type_of(credential) + return credential + + +def validate_credential( + credential: ThpPairingCredential, host_static_pubkey: bytes, ) -> bool: """ Validate a pairing credential binded to the provided host static public key. """ cred_auth_key = derive_cred_auth_key() - expected_type = protobuf.type_for_name("ThpPairingCredential") - credential = wrap_protobuf_load(encoded_pairing_credential_message, expected_type) - assert ThpPairingCredential.is_type_of(credential) proto_msg = ThpAuthenticatedCredentialData( host_static_pubkey=host_static_pubkey, cred_metadata=credential.cred_metadata, @@ -83,6 +92,27 @@ def validate_credential( return mac == credential.mac +def decode_and_validate_credential( + encoded_pairing_credential_message: bytes, + host_static_pubkey: bytes, +) -> bool: + """ + Decode a protobuf encoded pairing credential and validate it + binded to the provided host static public key. + """ + credential = decode_credential(encoded_pairing_credential_message) + return validate_credential(credential, host_static_pubkey) + + +def is_credential_autoconnect(credential: ThpPairingCredential) -> bool: + assert ThpPairingCredential.is_type_of(credential) + if credential.cred_metadata is None: + return False + if credential.cred_metadata.autoconnect is None: + return False + return credential.cred_metadata.autoconnect + + def _encode_message_into_new_buffer(msg: protobuf.MessageType) -> bytes: msg_len = protobuf.encoded_length(msg) new_buffer = bytearray(msg_len) diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index b3eeb79956..a195f5420f 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -31,7 +31,7 @@ from trezor.wire.errors import ActionCancelled, SilentError, UnexpectedMessage from trezor.wire.thp import ChannelState, ThpError, crypto, get_enabled_pairing_methods from trezor.wire.thp.pairing_context import PairingContext -from .credential_manager import issue_credential +from .credential_manager import is_credential_autoconnect, issue_credential if __debug__: from trezor import log @@ -116,7 +116,8 @@ async def handle_pairing_request( ctx.channel_ctx.set_channel_state(ChannelState.TP3) try: - response = await ctx.show_pairing_method_screen() + # Should raise UnexpectedMessageException + await ctx.show_pairing_method_screen() except UnexpectedMessageException as e: raw_response = e.msg name = message_handler.get_msg_name(raw_response.type) @@ -137,12 +138,39 @@ async def handle_pairing_request( else: break - response = await _handle_different_pairing_methods(ctx, response) + response: protobuf.MessageType = await _handle_different_pairing_methods( + ctx, response + ) + return await handle_credential_phase( + ctx, + message=response, + show_connection_dialog=False, + ) - while ThpCredentialRequest.is_type_of(response): - response = await _handle_credential_request(ctx, response) - return await _handle_end_request(ctx, response) +@check_state_and_log(ChannelState.TC1) +async def handle_credential_phase( + ctx: PairingContext, + message: protobuf.MessageType, + show_connection_dialog: bool = True, +) -> ThpEndResponse: + autoconnect: bool = False + credential = ctx.channel_ctx.credential + + if credential is not None: + autoconnect = is_credential_autoconnect(credential) + if credential.cred_metadata is not None: + ctx.host_name = credential.cred_metadata.host_name + if ctx.host_name is None: + raise Exception("Credential does not have a hostname") + + if show_connection_dialog and not autoconnect: + await ctx.show_connection_dialogue() + + while ThpCredentialRequest.is_type_of(message): + message = await _handle_credential_request(ctx, message) + + return await _handle_end_request(ctx, message) async def _prepare_pairing(ctx: PairingContext) -> None: @@ -375,8 +403,15 @@ async def _handle_credential_request( if message.host_static_pubkey is None: raise Exception("Invalid message") # TODO change failure type + autoconnect: bool = False + if message.autoconnect is not None: + autoconnect = message.autoconnect + trezor_static_pubkey = crypto.get_trezor_static_pubkey() - credential_metadata = ThpCredentialMetadata(host_name=ctx.host_name) + credential_metadata = ThpCredentialMetadata( + host_name=ctx.host_name, + autoconnect=autoconnect, + ) credential = issue_credential(message.host_static_pubkey, credential_metadata) return await ctx.call_any( diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 4913c93bd5..39baa27e49 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -45,6 +45,8 @@ if TYPE_CHECKING: from trezorio import WireInterface from typing import Awaitable + from trezor.messages import ThpPairingCredential + from .pairing_context import PairingContext from .session_context import GenericSessionContext @@ -77,6 +79,7 @@ class Channel: # Temporary objects self.handshake: crypto.Handshake | None = None + self.credential: ThpPairingCredential | None = None self.connection_context: PairingContext | None = None self.busy_decoder: crypto.BusyDecoder | None = None self.temp_crc: int | None = None diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 3b45dcb9dd..d5e911e563 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -8,7 +8,7 @@ from trezor.wire import context, message_handler, protocol_common from trezor.wire.context import UnexpectedMessageException from trezor.wire.errors import ActionCancelled, SilentError from trezor.wire.protocol_common import Context, Message -from trezor.wire.thp import get_enabled_pairing_methods +from trezor.wire.thp import ChannelState, get_enabled_pairing_methods if TYPE_CHECKING: from typing import Awaitable, Container @@ -92,7 +92,7 @@ class PairingContext(Context): self.display_data: PairingDisplayData = PairingDisplayData() self.cpace: Cpace - self.host_name: str + self.host_name: str | None async def handle(self) -> None: next_message: Message | None = None @@ -115,7 +115,7 @@ class PairingContext(Context): next_message = None try: - next_message = await handle_pairing_request_message(self, message) + next_message = await handle_message(self, message) except Exception as exc: # Log and ignore. The session handler can only exit explicitly in the # following finally block. @@ -219,6 +219,19 @@ class PairingContext(Context): if result == trezorui_api.CONFIRMED: await self.write(ThpPairingRequestApproved()) + async def show_connection_dialogue(self) -> None: + from trezor.ui.layouts.common import interact + + await interact( + trezorui_api.confirm_action( + title="Connection dialogue", + action="Do you want previously connected device to connect?", + description="Choose wisely! (or not)", + ), + br_name="connection_request", + br_code=ButtonRequestType.Other, + ) + async def show_pairing_method_screen( self, selected_method: ThpPairingMethod | None = None ) -> UiResult: @@ -264,14 +277,14 @@ class PairingContext(Context): return result -async def handle_pairing_request_message( +async def handle_message( pairing_ctx: PairingContext, msg: protocol_common.Message, ) -> protocol_common.Message | None: res_msg: protobuf.MessageType | None = None - from apps.thp.pairing import handle_pairing_request + from apps.thp.pairing import handle_pairing_request, handle_credential_phase if msg.type in workflow.ALLOW_WHILE_LOCKED: workflow.autolock_interrupts_workflow = False @@ -292,7 +305,10 @@ async def handle_pairing_request_message( req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) # Create the handler task. - task = handle_pairing_request(pairing_ctx, req_msg) + if pairing_ctx.channel_ctx.get_channel_state() == ChannelState.TC1: + task = handle_credential_phase(pairing_ctx, req_msg) + else: + task = handle_pairing_request(pairing_ctx, req_msg) # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py index 5efc066e99..f176615387 100644 --- a/core/src/trezor/wire/thp/received_message_handler.py +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -61,6 +61,7 @@ if __debug__: _TREZOR_STATE_UNPAIRED = b"\x00" _TREZOR_STATE_PAIRED = b"\x01" +_TREZOR_STATE_PAIRED_AUTOCONNECT = b"\x02" async def handle_received_message( @@ -264,7 +265,7 @@ async def _handle_state_TH1( async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -> None: - from apps.thp.credential_manager import validate_credential + from apps.thp.credential_manager import decode_credential, validate_credential if __debug__ and utils.ALLOW_DEBUG_MESSAGES: log.debug(__name__, "handle_state_TH2") @@ -322,21 +323,25 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) - host_static_pubkey = host_encrypted_static_pubkey[:PUBKEY_LENGTH] paired: bool = False + trezor_state = _TREZOR_STATE_UNPAIRED if noise_payload.host_pairing_credential is not None: try: # TODO change try-except for something better + credential = decode_credential(noise_payload.host_pairing_credential) paired = validate_credential( - noise_payload.host_pairing_credential, + credential, host_static_pubkey, ) + if paired: + trezor_state = _TREZOR_STATE_PAIRED + ctx.credential = credential + else: + ctx.credential = None except DataError as e: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: log.exception(__name__, e) pass - trezor_state = _TREZOR_STATE_UNPAIRED - if paired: - trezor_state = _TREZOR_STATE_PAIRED # send hanshake completion response ctx.write_handshake_message( HANDSHAKE_COMP_RES, diff --git a/core/tests/test_apps.thp.credential_manager.py b/core/tests/test_apps.thp.credential_manager.py index 267707d374..e25cc6a002 100644 --- a/core/tests/test_apps.thp.credential_manager.py +++ b/core/tests/test_apps.thp.credential_manager.py @@ -48,16 +48,28 @@ class TestTrezorHostProtocolCredentialManager(unittest.TestCase): cred_3 = _issue_credential(HOST_NAME_2, DUMMY_KEY_1) self.assertNotEqual(cred_1, cred_3) - self.assertTrue(credential_manager.validate_credential(cred_1, DUMMY_KEY_1)) - self.assertTrue(credential_manager.validate_credential(cred_3, DUMMY_KEY_1)) - self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_2)) + self.assertTrue( + credential_manager.decode_and_validate_credential(cred_1, DUMMY_KEY_1) + ) + self.assertTrue( + credential_manager.decode_and_validate_credential(cred_3, DUMMY_KEY_1) + ) + self.assertFalse( + credential_manager.decode_and_validate_credential(cred_1, DUMMY_KEY_2) + ) credential_manager.invalidate_cred_auth_key() cred_4 = _issue_credential(HOST_NAME_1, DUMMY_KEY_1) self.assertNotEqual(cred_1, cred_4) - self.assertFalse(credential_manager.validate_credential(cred_1, DUMMY_KEY_1)) - self.assertFalse(credential_manager.validate_credential(cred_3, DUMMY_KEY_1)) - self.assertTrue(credential_manager.validate_credential(cred_4, DUMMY_KEY_1)) + self.assertFalse( + credential_manager.decode_and_validate_credential(cred_1, DUMMY_KEY_1) + ) + self.assertFalse( + credential_manager.decode_and_validate_credential(cred_3, DUMMY_KEY_1) + ) + self.assertTrue( + credential_manager.decode_and_validate_credential(cred_4, DUMMY_KEY_1) + ) def test_protobuf_encoding(self): """ diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index 1604ae64ea..2eaa92d648 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -188,7 +188,9 @@ class ProtocolV2(ProtocolAndChannel): device_properties = payload[10:] return (channel_id, device_properties) - def _do_handshake(self): + def _do_handshake( + self, credential: bytes | None = None, host_static_privkey: bytes | None = None + ): host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) @@ -208,6 +210,8 @@ class ProtocolV2(ProtocolAndChannel): host_ephemeral_privkey, trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, + credential, + host_static_privkey, ) self._read_ack() self._read_handshake_completion_response() @@ -246,6 +250,7 @@ class ProtocolV2(ProtocolAndChannel): trezor_ephemeral_pubkey: bytes, encrypted_trezor_static_pubkey: bytes, credential: bytes | None = None, + host_static_privkey: bytes | None = None, ) -> bytes: PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" @@ -275,16 +280,25 @@ class ProtocolV2(ProtocolAndChannel): tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) h = _sha256_of_two(h, tag_of_empty_string) - # TODO: search for saved credentials (or possibly not, as we skip pairing phase) - zeroes_32 = int.to_bytes(0, 32, "little") - temp_host_static_privkey = curve25519.get_private_key(zeroes_32) - temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey) + # TODO: search for saved credentials + if host_static_privkey is not None and credential is not None: + host_static_pubkey = curve25519.get_public_key(host_static_privkey) + else: + credential = None + zeroes_32 = int.to_bytes(0, 32, "little") + temp_host_static_privkey = curve25519.get_private_key(zeroes_32) + temp_host_static_pubkey = curve25519.get_public_key( + temp_host_static_privkey + ) + host_static_privkey = temp_host_static_privkey + host_static_pubkey = temp_host_static_pubkey + aes_ctx = AESGCM(k) - encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h) + encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h) h = _sha256_of_two(h, encrypted_host_static_pubkey) ck, k = _hkdf( - ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey) + ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey) ) msg_data = self.mapping.encode_without_wire_type( messages.ThpHandshakeCompletionReqNoisePayload( diff --git a/tests/device_tests/thp/test_thp.py b/tests/device_tests/thp/test_thp.py index 6b4569ca5d..2b7c7fe1f9 100644 --- a/tests/device_tests/thp/test_thp.py +++ b/tests/device_tests/thp/test_thp.py @@ -17,6 +17,8 @@ from trezorlib.messages import ( ThpCodeEntryCpaceHostTag, ThpCodeEntryCpaceTrezor, ThpCodeEntrySecret, + ThpCredentialRequest, + ThpCredentialResponse, ThpEndRequest, ThpEndResponse, ThpNfcTagHost, @@ -55,6 +57,18 @@ def _prepare_protocol_for_pairing(client: Client) -> ProtocolV2: return protocol +def _handle_pairing_request(client: Client, protocol: ProtocolV2) -> None: + protocol._send_message(ThpPairingRequest()) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "pairing_request" + + protocol._send_message(ButtonAck()) + + client.debug.press_yes() + + protocol._read_message(ThpPairingRequestApproved) + + def test_allocate_channel(client: Client) -> None: protocol = _prepare_protocol(client) @@ -112,14 +126,7 @@ def test_handshake(client: Client) -> None: def test_pairing_qr_code(client: Client) -> None: protocol = _prepare_protocol_for_pairing(client) - - protocol._send_message(ThpPairingRequest()) - protocol._read_message(ButtonRequest) - protocol._send_message(ButtonAck()) - - client.debug.press_yes() - - protocol._read_message(ThpPairingRequestApproved) + _handle_pairing_request(client, protocol) protocol._send_message( ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode) ) @@ -161,13 +168,8 @@ def test_pairing_qr_code(client: Client) -> None: def test_pairing_code_entry(client: Client) -> None: protocol = _prepare_protocol_for_pairing(client) - protocol._send_message(ThpPairingRequest()) - protocol._read_message(ButtonRequest) - protocol._send_message(ButtonAck()) + _handle_pairing_request(client, protocol) - client.debug.press_yes() - - protocol._read_message(ThpPairingRequestApproved) protocol._send_message( ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry) ) @@ -227,13 +229,17 @@ def test_pairing_code_entry(client: Client) -> None: def test_pairing_nfc(client: Client) -> None: protocol = _prepare_protocol_for_pairing(client) - protocol._send_message(ThpPairingRequest()) - protocol._read_message(ButtonRequest) - protocol._send_message(ButtonAck()) + _nfc_pairing(client, protocol) - client.debug.press_yes() + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + protocol._has_valid_channel = True + + +def _nfc_pairing(client: Client, protocol: ProtocolV2): + + _handle_pairing_request(client, protocol) - protocol._read_message(ThpPairingRequestApproved) protocol._send_message( ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC) ) @@ -272,7 +278,55 @@ def test_pairing_nfc(client: Client) -> None: computed_tag = sha_ctx.digest() assert tag_trezor_msg.tag == computed_tag + +def test_credential_phase(client: Client): + protocol = _prepare_protocol_for_pairing(client) + _nfc_pairing(client, protocol) + + # Request credential with confirmation after pairing + host_static_privkey = curve25519.get_private_key(os.urandom(32)) + host_static_pubkey = curve25519.get_public_key(host_static_privkey) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False) + ) + credential_response = protocol._read_message(ThpCredentialResponse) + + assert credential_response.credential is not None + credential = credential_response.credential protocol._send_message(ThpEndRequest()) protocol._read_message(ThpEndResponse) - protocol._has_valid_channel = True + # Connect using credential with confirmation + protocol = _prepare_protocol(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential, host_static_privkey) + protocol._send_message(ThpEndRequest()) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "connection_request" + protocol._send_message(ButtonAck()) + client.debug.press_yes() + protocol._read_message(ThpEndResponse) + + # Connect using credential with confirmation and ask for autoconnect credential + protocol = _prepare_protocol(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential, host_static_privkey) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=True) + ) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "connection_request" + protocol._send_message(ButtonAck()) + client.debug.press_yes() + credential_response_2 = protocol._read_message(ThpCredentialResponse) + assert credential_response_2.credential is not None + credential_auto = credential_response_2.credential + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + # Connect using autoconnect credential + protocol = _prepare_protocol(client) + protocol._do_channel_allocation() + protocol._do_handshake(credential_auto, host_static_privkey) + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse)