diff --git a/core/src/apps/base.py b/core/src/apps/base.py index b8650e6861..08b35ef5a6 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -32,7 +32,7 @@ if TYPE_CHECKING: Failure, ThpCreateNewSession, ThpCredentialRequest, - ThpPairingCredential, + ThpCredentialResponse, ) if utils.USE_THP: @@ -283,10 +283,11 @@ if utils.USE_THP: async def handle_ThpCredentialRequest( message: ThpCredentialRequest, - ) -> ThpPairingCredential | Failure: + ) -> ThpCredentialResponse | Failure: from storage.cache_common import CHANNEL_HOST_STATIC_PUBKEY - from trezor.messages import ThpCredentialMetadata + from trezor.messages import ThpCredentialMetadata, ThpCredentialResponse from trezor.wire.context import get_context + from trezor.wire.thp import crypto from trezor.wire.thp.session_context import GenericSessionContext from apps.thp.credential_manager import ( @@ -325,13 +326,19 @@ if utils.USE_THP: await ui.show_autoconnect_credential_confirmation_screen( ctx, cred_metadata.host_name ) - return issue_credential( + new_cred = issue_credential( host_static_pubkey=host_static_pubkey, credential_metadata=cred_metadata, ) + trezor_static_pubkey = crypto.get_trezor_static_pubkey() + + return ThpCredentialResponse( + trezor_static_pubkey=trezor_static_pubkey, credential=new_cred + ) def _get_autoconnect_failure() -> Failure: from trezor.enums import FailureType + from trezor.messages import Failure return Failure( code=FailureType.DataError, diff --git a/tests/device_tests/thp/test_pairing.py b/tests/device_tests/thp/test_pairing.py index 7d367000a1..2e6fc55421 100644 --- a/tests/device_tests/thp/test_pairing.py +++ b/tests/device_tests/thp/test_pairing.py @@ -426,6 +426,36 @@ def test_credential_phase(client: Client) -> None: protocol._read_message(ThpEndResponse) +def test_credential_request_in_encrypted_transport_phase(client: Client) -> None: + randomness_static = os.urandom(32) + protocol = prepare_protocol_for_pairing(client, randomness_static) + _nfc_pairing(client, protocol) + + # Request credential with confirmation after pairing + host_static_privkey = curve25519.get_private_key(randomness_static) + host_static_pubkey = curve25519.get_public_key(host_static_privkey) + protocol._send_message( + ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False) + ) + credential_response = protocol._read_message(ThpCredentialResponse) + + assert credential_response.credential is not None + credential = credential_response.credential + protocol._send_message(ThpEndRequest()) + protocol._read_message(ThpEndResponse) + + session = client.get_seedless_session() + + session.call( + ThpCredentialRequest( + host_static_pubkey=host_static_pubkey, + autoconnect=True, + credential=credential, + ), + expect=ThpCredentialResponse, + ) + + @pytest.mark.setup_client(passphrase=True) def test_channel_replacement(client: Client) -> None: assert client.features.passphrase_protection is True