diff --git a/core/tests/test_trezor.wire.thp.py b/core/tests/test_trezor.wire.thp.py new file mode 100644 index 000000000..704312098 --- /dev/null +++ b/core/tests/test_trezor.wire.thp.py @@ -0,0 +1,71 @@ +from common import * + +from trezor import io +from trezor.loop import wait +from trezor.wire import thp_v1 +from trezor.wire.thp import channel +from storage import cache_thp +from ubinascii import hexlify +from trezor.wire.thp import ChannelState + + +class MockHID: + def __init__(self, num): + self.num = num + self.data = [] + + def iface_num(self): + return self.num + + def write(self, msg): + self.data.append(bytearray(msg)) + return len(msg) + + def wait_object(self, mode): + return wait(mode | self.num) + + +def dummy_decode_iface(cached_iface: bytes): + return MockHID(0xDEADBEEF) + + +def getBytes(a): + return hexlify(a).decode("utf-8") + + +class TestTrezorHostProtocol(unittest.TestCase): + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + buffer = bytearray(64) + thp_v1.set_buffer(buffer) + channel._decode_iface = dummy_decode_iface + + def test_simple(self): + self.assertTrue(True) + + def test_channel_allocation(self): + cid_req = ( + b"\x40\xff\xff\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c" + ) + expected_response = "41ffff001e001122334455667712340a0454335731100518002001280128026dcad4ba0000000000000000000000000000000000000000000000000000000000" + test_counter = cache_thp.cid_counter + 1 + self.assertEqual(len(thp_v1.CHANNELS), 0) + self.assertFalse(test_counter in thp_v1.CHANNELS) + gen = thp_v1.thp_main_loop(self.interface, is_debug_session=True) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + gen.send(cid_req) + gen.send(None) + self.assertEqual( + getBytes(self.interface.data[-1]), + expected_response, + ) + self.assertTrue(test_counter in thp_v1.CHANNELS) + self.assertEqual(len(thp_v1.CHANNELS), 1) + + def test_channel_default_state_is_TH1(self): + self.assertEqual(thp_v1.CHANNELS[4660].get_channel_state(), ChannelState.TH1) + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_trezor.wire.thp_v1.py b/core/tests/test_trezor.wire.thp_v1.py index e43c128ba..344f072b2 100644 --- a/core/tests/test_trezor.wire.thp_v1.py +++ b/core/tests/test_trezor.wire.thp_v1.py @@ -1,12 +1,13 @@ from common import * -from ubinascii import hexlify, unhexlify +from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor.wire.thp.writer import REPORT_LENGTH +from ubinascii import hexlify import ustruct from trezor import io, utils from trezor.loop import wait from trezor.utils import chunks from trezor.wire import thp_v1 -from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID from trezor.wire.protocol_common import MessageWithId import trezor.wire.thp.thp_session as THP from trezor.wire.thp import checksum @@ -39,9 +40,7 @@ CONT = 0x80 HEADER_INIT_LENGTH = 5 HEADER_CONT_LENGTH = 3 -INIT_MESSAGE_DATA_LENGTH = ( - thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN -) +INIT_MESSAGE_DATA_LENGTH = REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN def make_header(ctrl_byte, cid, length): @@ -81,7 +80,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): if not utils.USE_THP: import storage.cache_thp # noqa: F401 - def test_simple(self): + def _simple(self): cid_req_header = make_header( ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12 ) @@ -116,7 +115,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): # message should have been read into the buffer self.assertEqual(buffer_without_zeroes, message_without_header) - def test_read_one_packet(self): + def _read_one_packet(self): # zero length message - just a header PLAINTEXT = getPlaintext() header = make_header( @@ -142,7 +141,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): # message should have been read into the buffer self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58) - def test_read_many_packets(self): + def _read_many_packets(self): message = bytes(range(256)) header = make_header( getPlaintext(), @@ -182,7 +181,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): # message should have been read into the buffer ) self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum) - def test_read_large_message(self): + def _read_large_message(self): message = b"hello world" header = make_header( getPlaintext(), @@ -218,7 +217,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): # read should have allocated its own buffer and not touch ours self.assertEqual(buffer, b"\x00") - def test_roundtrip(self): + def _roundtrip(self): message_payload = bytes(range(256)) message = MessageWithId( MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) @@ -244,7 +243,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): self.assertEqual(result.type, MESSAGE_TYPE) self.assertEqual(result.data, message.data) - def test_write_one_packet(self): + def _write_one_packet(self): message = MessageWithId( MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID) ) @@ -266,7 +265,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): ) self.assertTrue(self.interface.data == [expected_message]) - def test_write_multiple_packets(self): + def _write_multiple_packets(self): message_payload = bytes(range(256)) message = MessageWithId( MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) @@ -310,14 +309,11 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): last_packet = packets[-1] + packets[-2][len(packets[-1]) :] self.assertEqual(last_packet, self.interface.data[-1]) - def test_read_huge_packet(self): + def _read_huge_packet(self): PACKET_COUNT = 1180 # message that takes up 1 180 USB packets message_size = (PACKET_COUNT - 1) * ( - thp_v1._REPORT_LENGTH - - HEADER_CONT_LENGTH - - CHECKSUM_LENGTH - - _MESSAGE_TYPE_LEN + REPORT_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN ) + INIT_MESSAGE_DATA_LENGTH # ensure that a message this big won't fit into memory