diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 82664fe189..eeef009555 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -141,11 +141,15 @@ class Channel: self._handle_received_packet(packet) + if self.expected_payload_length == 0: # Reading failed TODO + from trezor.wire.thp import ThpErrorType + + return self.write_error(ThpErrorType.TRANSPORT_BUSY) + try: buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) except WireBufferError: pass # TODO ?? - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: try: self._log("self.buffer: ", get_bytes_as_str(buffer)) @@ -157,6 +161,7 @@ class Channel: if self.fallback_decrypt: # TODO Check CRC and if valid, check tag, if valid update nonces self._finish_fallback() + self.write() # TODO self.write() failure device is busy - use channel buffer to send this failure message!! return None return received_message_handler.handle_received_message(self, buffer) @@ -204,7 +209,14 @@ class Channel: self.fallback_decrypt = True - self._prepare_fallback() + try: + self._prepare_fallback() + except Exception: + self.fallback_decrypt = False + self.expected_payload_length = 0 + self.bytes_read = 0 + print("FAILED TO FALLBACK") + return to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length) buf = memoryview(self.buffer)[:to_read_len] @@ -417,7 +429,8 @@ class Channel: buffer, msg, session_id ) except WireBufferError: - from trezor.messages import Failure, FailureType + from trezor.enums import FailureType + from trezor.messages import Failure if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("Failed to get write buffer, killing channel.") diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index 47dcd471e8..df4a77ced3 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -130,7 +130,9 @@ class ProtocolV2Channel(Channel): device_properties = payload[10:] return (channel_id, device_properties) - def _init_noise(self, randomness_static: bytes) -> None: + def _init_noise(self, randomness_static: bytes | None = None) -> None: + if randomness_static is None: + randomness_static = os.urandom(32) self._noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256") self._noise.set_as_initiator() self._noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static) diff --git a/tests/device_tests/thp/connect.py b/tests/device_tests/thp/connect.py new file mode 100644 index 0000000000..ac2e81d6f9 --- /dev/null +++ b/tests/device_tests/thp/connect.py @@ -0,0 +1,48 @@ +from trezorlib.client import ProtocolV2Channel +from trezorlib.debuglink import TrezorClientDebugLink as Client +from trezorlib.messages import ( + ButtonAck, + ButtonRequest, + ThpPairingRequest, + ThpPairingRequestApproved, +) + + +def prepare_protocol_for_handshake(client: Client) -> ProtocolV2Channel: + protocol = client.protocol + assert isinstance(protocol, ProtocolV2Channel) + protocol._reset_sync_bits() + protocol._do_channel_allocation() + return protocol + + +def prepare_protocol_for_pairing( + client: Client, host_static_randomness: bytes | None = None +) -> ProtocolV2Channel: + protocol = prepare_protocol_for_handshake(client) + protocol._do_handshake(host_static_randomness=host_static_randomness) + return protocol + + +def get_encrypted_transport_protocol( + client: Client, host_static_randomness: bytes | None = None +) -> ProtocolV2Channel: + protocol = prepare_protocol_for_pairing( + client, host_static_randomness=host_static_randomness + ) + protocol._do_pairing(client.debug) + return protocol + + +def handle_pairing_request( + client: Client, protocol: ProtocolV2Channel, host_name: str | None = None +) -> None: + protocol._send_message(ThpPairingRequest(host_name=host_name)) + button_req = protocol._read_message(ButtonRequest) + assert button_req.name == "pairing_request" + + protocol._send_message(ButtonAck()) + + client.debug.press_yes() + + protocol._read_message(ThpPairingRequestApproved) diff --git a/tests/device_tests/thp/test_handshake.py b/tests/device_tests/thp/test_handshake.py new file mode 100644 index 0000000000..b773a7bc5c --- /dev/null +++ b/tests/device_tests/thp/test_handshake.py @@ -0,0 +1,51 @@ +import os +import pytest + +from trezorlib.client import ProtocolV2Channel +from trezorlib.debuglink import TrezorClientDebugLink as Client + + +from .connect import prepare_protocol_for_handshake + + +pytestmark = [pytest.mark.protocol("protocol_v2")] + + +def test_allocate_channel(client: Client) -> None: + assert isinstance(client.protocol, ProtocolV2Channel) + + nonce = os.urandom(8) + + # Use valid nonce + client.protocol._send_channel_allocation_request(nonce) + client.protocol._read_channel_allocation_response(nonce) + + # Expect different nonce + client.protocol._send_channel_allocation_request(nonce) + with pytest.raises(Exception, match="Invalid channel allocation response."): + client.protocol._read_channel_allocation_response( + expected_nonce=b"\xde\xad\xbe\xef\xde\xad\xbe\xef" + ) + client.invalidate() + + +def test_handshake(client: Client) -> None: + protocol = prepare_protocol_for_handshake(client) + + randomness_static = os.urandom(32) + + protocol._do_channel_allocation() + protocol._init_noise( + randomness_static=randomness_static, + ) + protocol._send_handshake_init_request() + protocol._read_ack() + protocol._read_handshake_init_response() + + protocol._send_handshake_completion_request() + protocol._read_ack() + protocol._read_handshake_completion_response() + + # TODO - without pairing, the client is damaged and results in fail of the following test + # so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error + protocol._do_pairing(client.debug) diff --git a/tests/device_tests/thp/test_multiple_apps.py b/tests/device_tests/thp/test_multiple_apps.py new file mode 100644 index 0000000000..adea18a2b9 --- /dev/null +++ b/tests/device_tests/thp/test_multiple_apps.py @@ -0,0 +1,39 @@ +import os +import time +import pytest +from trezorlib.client import ProtocolV2Channel + +from trezorlib.debuglink import TrezorClientDebugLink as Client + +pytestmark = [pytest.mark.protocol("protocol_v2")] + + +def test_multiple_hosts(client: Client) -> None: + assert isinstance(client.protocol, ProtocolV2Channel) + protocol_1 = client.protocol + protocol_2 = ProtocolV2Channel(protocol_1.transport, protocol_1.mapping) + + nonce_1 = os.urandom(8) + nonce_2 = os.urandom(8) + if nonce_1 == nonce_2: + nonce_2 = (int.from_bytes(nonce_1) + 1).to_bytes(8, "big") + protocol_1._send_channel_allocation_request(nonce_1) + protocol_1.channel_id, protocol_1.device_properties = ( + protocol_1._read_channel_allocation_response(nonce_1) + ) + protocol_2._send_channel_allocation_request(nonce_2) + protocol_2.channel_id, protocol_2.device_properties = ( + protocol_2._read_channel_allocation_response(nonce_2) + ) + + protocol_1._init_noise() + protocol_2._init_noise() + + protocol_1._send_handshake_init_request() + protocol_1._read_ack() + protocol_1._read_handshake_init_response() + + time.sleep(0.2) # To pass LOCK_TIME + protocol_2._send_handshake_init_request() + protocol_2._read_ack() + # protocol_2._read_handshake_init_response() diff --git a/tests/device_tests/thp/test_thp.py b/tests/device_tests/thp/test_pairing.py similarity index 76% rename from tests/device_tests/thp/test_thp.py rename to tests/device_tests/thp/test_pairing.py index 7fbd899689..725c6908d4 100644 --- a/tests/device_tests/thp/test_thp.py +++ b/tests/device_tests/thp/test_pairing.py @@ -25,8 +25,6 @@ from trezorlib.messages import ( ThpNfcTagTrezor, ThpPairingMethod, ThpPairingPreparationsFinished, - ThpPairingRequest, - ThpPairingRequestApproved, ThpQrCodeSecret, ThpQrCodeTag, ThpSelectMethod, @@ -34,6 +32,13 @@ from trezorlib.messages import ( from trezorlib.transport.thp import curve25519 from trezorlib.transport.thp.cpace import Cpace +from .connect import ( + prepare_protocol_for_handshake, + prepare_protocol_for_pairing, + get_encrypted_transport_protocol, + handle_pairing_request, +) + if t.TYPE_CHECKING: P = tx.ParamSpec("P") @@ -42,89 +47,9 @@ MT = t.TypeVar("MT", bound=protobuf.MessageType) pytestmark = [pytest.mark.protocol("protocol_v2")] -def _prepare_protocol(client: Client) -> ProtocolV2Channel: - protocol = client.protocol - assert isinstance(protocol, ProtocolV2Channel) - protocol._reset_sync_bits() - protocol._do_channel_allocation() - return protocol - - -def _prepare_protocol_for_pairing( - client: Client, host_static_randomness: bytes | None = None -) -> ProtocolV2Channel: - protocol = _prepare_protocol(client) - protocol._do_handshake(host_static_randomness=host_static_randomness) - return protocol - - -def _get_encrypted_transport_protocol( - client: Client, host_static_randomness: bytes | None = None -) -> ProtocolV2Channel: - protocol = _prepare_protocol_for_pairing( - client, host_static_randomness=host_static_randomness - ) - protocol._do_pairing(client.debug) - return protocol - - -def _handle_pairing_request( - client: Client, protocol: ProtocolV2Channel, host_name: str | None = None -) -> None: - protocol._send_message(ThpPairingRequest(host_name=host_name)) - button_req = protocol._read_message(ButtonRequest) - assert button_req.name == "pairing_request" - - protocol._send_message(ButtonAck()) - - client.debug.press_yes() - - protocol._read_message(ThpPairingRequestApproved) - - -def test_allocate_channel(client: Client) -> None: - protocol = _prepare_protocol(client) - - nonce = os.urandom(8) - - # Use valid nonce - protocol._send_channel_allocation_request(nonce) - protocol._read_channel_allocation_response(nonce) - - # Expect different nonce - protocol._send_channel_allocation_request(nonce) - with pytest.raises(Exception, match="Invalid channel allocation response."): - protocol._read_channel_allocation_response( - expected_nonce=b"\xde\xad\xbe\xef\xde\xad\xbe\xef" - ) - client.invalidate() - - -def test_handshake(client: Client) -> None: - protocol = _prepare_protocol(client) - - randomness_static = os.urandom(32) - - protocol._do_channel_allocation() - protocol._init_noise( - randomness_static=randomness_static, - ) - protocol._send_handshake_init_request() - protocol._read_ack() - protocol._read_handshake_init_response() - - protocol._send_handshake_completion_request() - protocol._read_ack() - protocol._read_handshake_completion_response() - - # TODO - without pairing, the client is damaged and results in fail of the following test - # so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error - protocol._do_pairing(client.debug) - - def test_pairing_qr_code(client: Client) -> None: - protocol = _prepare_protocol_for_pairing(client) - _handle_pairing_request(client, protocol, "TestTrezor QrCode") + protocol = prepare_protocol_for_pairing(client) + handle_pairing_request(client, protocol, "TestTrezor QrCode") protocol._send_message( ThpSelectMethod(selected_pairing_method=ThpPairingMethod.QrCode) ) @@ -162,9 +87,9 @@ def test_pairing_qr_code(client: Client) -> None: def test_pairing_code_entry(client: Client) -> None: - protocol = _prepare_protocol_for_pairing(client) + protocol = prepare_protocol_for_pairing(client) - _handle_pairing_request(client, protocol, "TestTrezor CodeEntry") + handle_pairing_request(client, protocol, "TestTrezor CodeEntry") protocol._send_message( ThpSelectMethod(selected_pairing_method=ThpPairingMethod.CodeEntry) @@ -221,7 +146,7 @@ def test_pairing_code_entry(client: Client) -> None: def test_pairing_nfc(client: Client) -> None: - protocol = _prepare_protocol_for_pairing(client) + protocol = prepare_protocol_for_pairing(client) _nfc_pairing(client, protocol) @@ -232,7 +157,7 @@ def test_pairing_nfc(client: Client) -> None: def _nfc_pairing(client: Client, protocol: ProtocolV2Channel) -> None: - _handle_pairing_request(client, protocol, "TestTrezor NfcPairing") + handle_pairing_request(client, protocol, "TestTrezor NfcPairing") protocol._send_message( ThpSelectMethod(selected_pairing_method=ThpPairingMethod.NFC) @@ -272,7 +197,7 @@ def _nfc_pairing(client: Client, protocol: ProtocolV2Channel) -> None: def test_credential_phase(client: Client) -> None: - protocol = _prepare_protocol_for_pairing(client) + protocol = prepare_protocol_for_pairing(client) _nfc_pairing(client, protocol) # Request credential with confirmation after pairing @@ -290,7 +215,7 @@ def test_credential_phase(client: Client) -> None: protocol._read_message(ThpEndResponse) # Connect using credential with confirmation - protocol = _prepare_protocol(client) + protocol = prepare_protocol_for_handshake(client) protocol._do_channel_allocation() protocol._do_handshake(credential, randomness_static) protocol._send_message(ThpEndRequest()) @@ -310,7 +235,7 @@ def test_credential_phase(client: Client) -> None: assert e.value.args[0] == "Received ThpError: DECRYPTION FAILED" # Connect using credential with confirmation and ask for autoconnect credential. - protocol = _prepare_protocol(client) + protocol = prepare_protocol_for_handshake(client) protocol._do_channel_allocation() protocol._do_handshake(credential, randomness_static) protocol._send_message( @@ -334,7 +259,7 @@ def test_credential_phase(client: Client) -> None: protocol._read_message(ThpEndResponse) # Connect using credential with confirmation - protocol = _prepare_protocol(client) + protocol = prepare_protocol_for_handshake(client) protocol._do_channel_allocation() protocol._do_handshake(credential, randomness_static) # Confirmation dialog is not shown as channel in ENCRYPTED TRANSPORT state with the same @@ -343,7 +268,7 @@ def test_credential_phase(client: Client) -> None: protocol._read_message(ThpEndResponse) # Connect using autoconnect credential - protocol = _prepare_protocol(client) + protocol = prepare_protocol_for_handshake(client) protocol._do_channel_allocation() protocol._do_handshake(credential_auto, randomness_static) protocol._send_message(ThpEndRequest()) @@ -359,7 +284,7 @@ def test_credential_phase(client: Client) -> None: assert e.value.args[0] == "Received ThpError: DECRYPTION FAILED" # Connect using autoconnect credential - should work the same as above - protocol = _prepare_protocol(client) + protocol = prepare_protocol_for_handshake(client) protocol._do_channel_allocation() protocol._do_handshake(credential_auto, randomness_static) protocol._send_message(ThpEndRequest()) @@ -377,7 +302,7 @@ def test_channel_replacement(client: Client) -> None: assert host_static_privkey != host_static_privkey_2 - client.protocol = _get_encrypted_transport_protocol(client, host_static_randomness) + client.protocol = get_encrypted_transport_protocol(client, host_static_randomness) session = client.get_session(passphrase="TREZOR", session_id=b"\x10") address = get_test_address(session) @@ -387,7 +312,7 @@ def test_channel_replacement(client: Client) -> None: assert address != address_2 # create new channel using the same host_static_privkey - client.protocol = _get_encrypted_transport_protocol(client, host_static_randomness) + client.protocol = get_encrypted_transport_protocol(client, host_static_randomness) session_3 = client.get_session(passphrase="OKIDOKI", session_id=b"\x30") address_3 = get_test_address(session_3) assert address_3 != address_2 @@ -399,9 +324,7 @@ def test_channel_replacement(client: Client) -> None: assert address_3 == new_address_3 # create new channel using different host_static_privkey - client.protocol = _get_encrypted_transport_protocol( - client, host_static_randomness_2 - ) + client.protocol = get_encrypted_transport_protocol(client, host_static_randomness_2) with pytest.raises(exceptions.TrezorFailure) as e_1: _ = get_test_address(session) assert str(e_1.value.message) == "Invalid session"