From a6db6b2dd6a44c0b322a9f10eb2b911bb0246b1c Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Mon, 24 Mar 2025 11:16:38 +0100 Subject: [PATCH] chore(test): improve tests for multiple hosts --- tests/device_tests/thp/test_multiple_apps.py | 63 ------- tests/device_tests/thp/test_multiple_hosts.py | 177 ++++++++++++++++++ 2 files changed, 177 insertions(+), 63 deletions(-) delete mode 100644 tests/device_tests/thp/test_multiple_apps.py create mode 100644 tests/device_tests/thp/test_multiple_hosts.py diff --git a/tests/device_tests/thp/test_multiple_apps.py b/tests/device_tests/thp/test_multiple_apps.py deleted file mode 100644 index 0ba083edf5..0000000000 --- a/tests/device_tests/thp/test_multiple_apps.py +++ /dev/null @@ -1,63 +0,0 @@ -import os -import time - -import pytest - -from trezorlib.client import ProtocolV2Channel -from trezorlib.debuglink import TrezorClientDebugLink as Client - -pytestmark = [pytest.mark.protocol("protocol_v2")] - - -def test_multiple_hosts(client: Client) -> None: - assert isinstance(client.protocol, ProtocolV2Channel) - protocol_1 = client.protocol - protocol_2 = ProtocolV2Channel(protocol_1.transport, protocol_1.mapping) - protocol_2._reset_sync_bits() - - nonce_1 = os.urandom(8) - nonce_2 = os.urandom(8) - if nonce_1 == nonce_2: - nonce_2 = (int.from_bytes(nonce_1) + 1).to_bytes(8, "big") - protocol_1._send_channel_allocation_request(nonce_1) - protocol_1.channel_id, protocol_1.device_properties = ( - protocol_1._read_channel_allocation_response(nonce_1) - ) - protocol_2._send_channel_allocation_request(nonce_2) - protocol_2.channel_id, protocol_2.device_properties = ( - protocol_2._read_channel_allocation_response(nonce_2) - ) - - protocol_1._init_noise() - protocol_2._init_noise() - - protocol_1._send_handshake_init_request() - protocol_1._read_ack() - protocol_1._read_handshake_init_response() - - protocol_2._send_handshake_init_request() - - with pytest.raises(Exception) as e: - protocol_2._read_ack() - assert e.value.args[0] == "Received ThpError: TRANSPORT BUSY" - - time.sleep(0.2) # To pass LOCK_TIME - protocol_2._init_noise() - protocol_2._send_handshake_init_request() - protocol_2._read_ack() - protocol_2._read_handshake_init_response() - - protocol_2._send_handshake_completion_request() - protocol_2._read_ack() - protocol_2._read_handshake_completion_response() - - protocol_2._do_pairing(helper_debug=client.debug) - - protocol_1._send_handshake_completion_request() - protocol_1._read_ack() - - with pytest.raises(Exception) as e: - protocol_1._read_handshake_completion_response() - assert e.value.args[0] == "Received ThpError: UNALLOCATED CHANNEL" - - # TODO - test ACK fallback, test standard encrypted message fallback diff --git a/tests/device_tests/thp/test_multiple_hosts.py b/tests/device_tests/thp/test_multiple_hosts.py new file mode 100644 index 0000000000..5e38d414ad --- /dev/null +++ b/tests/device_tests/thp/test_multiple_hosts.py @@ -0,0 +1,177 @@ +import os +from time import sleep +import pytest + +from trezorlib import messages +from trezorlib.client import ProtocolV2Channel +from trezorlib.debuglink import TrezorClientDebugLink as Client +from ...conftest import LOCK_TIME + +pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client] + +# LOCK_TIME = 0.2 + + +def _prepare_two_hosts_for_handshake( + client: Client, init_noise: bool = True +) -> tuple[ProtocolV2Channel, ProtocolV2Channel]: + # Sleep for LOCK_TIME + sleep(LOCK_TIME) + + protocol_1 = client.protocol + protocol_1._reset_sync_bits() + protocol_2 = ProtocolV2Channel(protocol_1.transport, protocol_1.mapping) + protocol_2._reset_sync_bits() + + nonce_1 = os.urandom(8) + nonce_2 = os.urandom(8) + if nonce_1 == nonce_2: + nonce_2 = (int.from_bytes(nonce_1) + 1).to_bytes(8, "big") + protocol_1._send_channel_allocation_request(nonce_1) + protocol_1.channel_id, protocol_1.device_properties = ( + protocol_1._read_channel_allocation_response(nonce_1) + ) + protocol_2._send_channel_allocation_request(nonce_2) + protocol_2.channel_id, protocol_2.device_properties = ( + protocol_2._read_channel_allocation_response(nonce_2) + ) + if init_noise: + protocol_1._init_noise() + protocol_2._init_noise() + + return protocol_1, protocol_2 + + +def _prepare_two_hosts(client: Client) -> tuple[ProtocolV2Channel, ProtocolV2Channel]: + protocol_1, protocol_2 = _prepare_two_hosts_for_handshake( + client=client, init_noise=False + ) + protocol_1._do_handshake() + + protocol_1._do_pairing(client.debug) + sleep(LOCK_TIME) + protocol_2._do_handshake() + protocol_2._do_pairing(client.debug) + + return protocol_1, protocol_2 + + +def test_fallback_encrypted_transport(client: Client) -> None: + protocol_1, protocol_2 = _prepare_two_hosts(client) + client_1 = Client( + transport=client.transport, open_transport=True, protocol=protocol_1 + ) + client_2 = Client( + transport=client.transport, open_transport=True, protocol=protocol_2 + ) + session_1 = client_1.get_session(session_id=b"\x05") + session_2 = client_2.get_session(session_id=b"\x08") + msg = messages.GetFeatures() + + # Sequential calls should work without any problem + _ = session_1.call(msg) + _ = session_2.call(msg) + _ = session_1.call(msg) + _ = session_2.call(msg) + _ = session_1.call(msg) + _ = session_2.call(msg) + _ = session_1.call(msg) + _ = session_2.call(msg) + + # Zig-zag calls should invoke fallback + session_1._write(msg) + session_2._write(msg) + resp = session_2._read() + assert isinstance(resp, messages.Failure) + assert resp.message == "FALLBACK!" + sleep(LOCK_TIME) + session_2.call(msg) + + +def test_concurrent_handshakes_1(client: Client) -> None: + protocol_1, protocol_2 = _prepare_two_hosts_for_handshake(client) + + # The first host starts handshake + protocol_1._send_handshake_init_request() + protocol_1._read_ack() + protocol_1._read_handshake_init_response() + + # The second host starts handshake + protocol_2._send_handshake_init_request() + + # The second host should not be able to interrupt the first host's handshake + # until timeout (LOCK_TIME) has expired + with pytest.raises(Exception) as e: + protocol_2._read_ack() + assert e.value.args[0] == "Received ThpError: TRANSPORT BUSY" + + # Wait for LOCK_TIME to expire + sleep(LOCK_TIME) + + # The second host retries and finishes handhake successfully + protocol_2._init_noise() + protocol_2._send_handshake_init_request() + protocol_2._read_ack() + protocol_2._read_handshake_init_response() + + protocol_2._send_handshake_completion_request() + protocol_2._read_ack() + protocol_2._read_handshake_completion_response() + + # The second host performs action that results + # in the invalidation of the first host's handshake state + protocol_2._do_pairing(helper_debug=client.debug) + + # Even after LOCK_TIME passes, the first host's channel cannot + # be resumed + sleep(LOCK_TIME) + protocol_1._send_handshake_completion_request() + protocol_1._read_ack() + + with pytest.raises(Exception) as e: + protocol_1._read_handshake_completion_response() + assert e.value.args[0] == "Received ThpError: UNALLOCATED CHANNEL" + + +def test_concurrent_handshakes_2(client: Client) -> None: + protocol_1, protocol_2 = _prepare_two_hosts_for_handshake(client) + + # The first host starts handshake + protocol_1._send_handshake_init_request() + protocol_1._read_ack() + protocol_1._read_handshake_init_response() + + # The second host starts handshake + protocol_2._send_handshake_init_request() + + # The second host should not be able to interrupt the first host's handshake + # until timeout (LOCK_TIME) has expired + with pytest.raises(Exception) as e: + protocol_2._read_ack() + assert e.value.args[0] == "Received ThpError: TRANSPORT BUSY" + + # Wait for LOCK_TIME to expire + sleep(LOCK_TIME) + + # The second host retries and finishes handhake successfully + protocol_2._init_noise() + protocol_2._send_handshake_init_request() + protocol_2._read_ack() + protocol_2._read_handshake_init_response() + + protocol_2._send_handshake_completion_request() + protocol_2._read_ack() + protocol_2._read_handshake_completion_response() + + # The first host tries to continue handshake immediately after + # the second host finishes it + + protocol_1._send_handshake_completion_request() + + with pytest.raises(Exception) as e: + protocol_1._read_ack() + + # protocol_1._read_handshake_completion_response() + assert e.value.args[0] == "Received ThpError: TRANSPORT BUSY" + + # TODO - test ACK fallback, test standard encrypted message fallback