Edit thp tests

M1nd3r/thp2
M1nd3r 1 month ago
parent df8b367cb2
commit 45f525f916

@ -0,0 +1,71 @@
from common import *
from trezor import io
from trezor.loop import wait
from trezor.wire import thp_v1
from trezor.wire.thp import channel
from storage import cache_thp
from ubinascii import hexlify
from trezor.wire.thp import ChannelState
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
def dummy_decode_iface(cached_iface: bytes):
return MockHID(0xDEADBEEF)
def getBytes(a):
return hexlify(a).decode("utf-8")
class TestTrezorHostProtocol(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64)
thp_v1.set_buffer(buffer)
channel._decode_iface = dummy_decode_iface
def test_simple(self):
self.assertTrue(True)
def test_channel_allocation(self):
cid_req = (
b"\x40\xff\xff\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
)
expected_response = "41ffff001e001122334455667712340a0454335731100518002001280128026dcad4ba0000000000000000000000000000000000000000000000000000000000"
test_counter = cache_thp.cid_counter + 1
self.assertEqual(len(thp_v1.CHANNELS), 0)
self.assertFalse(test_counter in thp_v1.CHANNELS)
gen = thp_v1.thp_main_loop(self.interface, is_debug_session=True)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req)
gen.send(None)
self.assertEqual(
getBytes(self.interface.data[-1]),
expected_response,
)
self.assertTrue(test_counter in thp_v1.CHANNELS)
self.assertEqual(len(thp_v1.CHANNELS), 1)
def test_channel_default_state_is_TH1(self):
self.assertEqual(thp_v1.CHANNELS[4660].get_channel_state(), ChannelState.TH1)
if __name__ == "__main__":
unittest.main()

@ -1,12 +1,13 @@
from common import * from common import *
from ubinascii import hexlify, unhexlify from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor.wire.thp.writer import REPORT_LENGTH
from ubinascii import hexlify
import ustruct import ustruct
from trezor import io, utils from trezor import io, utils
from trezor.loop import wait from trezor.loop import wait
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire import thp_v1 from trezor.wire import thp_v1
from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID
from trezor.wire.protocol_common import MessageWithId from trezor.wire.protocol_common import MessageWithId
import trezor.wire.thp.thp_session as THP import trezor.wire.thp.thp_session as THP
from trezor.wire.thp import checksum from trezor.wire.thp import checksum
@ -39,9 +40,7 @@ CONT = 0x80
HEADER_INIT_LENGTH = 5 HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3 HEADER_CONT_LENGTH = 3
INIT_MESSAGE_DATA_LENGTH = ( INIT_MESSAGE_DATA_LENGTH = REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
)
def make_header(ctrl_byte, cid, length): def make_header(ctrl_byte, cid, length):
@ -81,7 +80,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
if not utils.USE_THP: if not utils.USE_THP:
import storage.cache_thp # noqa: F401 import storage.cache_thp # noqa: F401
def test_simple(self): def _simple(self):
cid_req_header = make_header( cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12 ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
) )
@ -116,7 +115,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer # message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header) self.assertEqual(buffer_without_zeroes, message_without_header)
def test_read_one_packet(self): def _read_one_packet(self):
# zero length message - just a header # zero length message - just a header
PLAINTEXT = getPlaintext() PLAINTEXT = getPlaintext()
header = make_header( header = make_header(
@ -142,7 +141,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer # message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58) self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def test_read_many_packets(self): def _read_many_packets(self):
message = bytes(range(256)) message = bytes(range(256))
header = make_header( header = make_header(
getPlaintext(), getPlaintext(),
@ -182,7 +181,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer ) # message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum) self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def test_read_large_message(self): def _read_large_message(self):
message = b"hello world" message = b"hello world"
header = make_header( header = make_header(
getPlaintext(), getPlaintext(),
@ -218,7 +217,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# read should have allocated its own buffer and not touch ours # read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00") self.assertEqual(buffer, b"\x00")
def test_roundtrip(self): def _roundtrip(self):
message_payload = bytes(range(256)) message_payload = bytes(range(256))
message = MessageWithId( message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
@ -244,7 +243,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(result.type, MESSAGE_TYPE) self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data) self.assertEqual(result.data, message.data)
def test_write_one_packet(self): def _write_one_packet(self):
message = MessageWithId( message = MessageWithId(
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID) MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
) )
@ -266,7 +265,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
) )
self.assertTrue(self.interface.data == [expected_message]) self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self): def _write_multiple_packets(self):
message_payload = bytes(range(256)) message_payload = bytes(range(256))
message = MessageWithId( message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
@ -310,14 +309,11 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
last_packet = packets[-1] + packets[-2][len(packets[-1]) :] last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1]) self.assertEqual(last_packet, self.interface.data[-1])
def test_read_huge_packet(self): def _read_huge_packet(self):
PACKET_COUNT = 1180 PACKET_COUNT = 1180
# message that takes up 1 180 USB packets # message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * ( message_size = (PACKET_COUNT - 1) * (
thp_v1._REPORT_LENGTH REPORT_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
- HEADER_CONT_LENGTH
- CHECKSUM_LENGTH
- _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH ) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory # ensure that a message this big won't fit into memory

Loading…
Cancel
Save