Edit thp tests

M1nd3r/thp2
M1nd3r 3 weeks ago
parent df8b367cb2
commit 45f525f916

@ -0,0 +1,71 @@
from common import *
from trezor import io
from trezor.loop import wait
from trezor.wire import thp_v1
from trezor.wire.thp import channel
from storage import cache_thp
from ubinascii import hexlify
from trezor.wire.thp import ChannelState
class MockHID:
def __init__(self, num):
self.num = num
self.data = []
def iface_num(self):
return self.num
def write(self, msg):
self.data.append(bytearray(msg))
return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
def dummy_decode_iface(cached_iface: bytes):
return MockHID(0xDEADBEEF)
def getBytes(a):
return hexlify(a).decode("utf-8")
class TestTrezorHostProtocol(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64)
thp_v1.set_buffer(buffer)
channel._decode_iface = dummy_decode_iface
def test_simple(self):
self.assertTrue(True)
def test_channel_allocation(self):
cid_req = (
b"\x40\xff\xff\x00\x0c\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
)
expected_response = "41ffff001e001122334455667712340a0454335731100518002001280128026dcad4ba0000000000000000000000000000000000000000000000000000000000"
test_counter = cache_thp.cid_counter + 1
self.assertEqual(len(thp_v1.CHANNELS), 0)
self.assertFalse(test_counter in thp_v1.CHANNELS)
gen = thp_v1.thp_main_loop(self.interface, is_debug_session=True)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
gen.send(cid_req)
gen.send(None)
self.assertEqual(
getBytes(self.interface.data[-1]),
expected_response,
)
self.assertTrue(test_counter in thp_v1.CHANNELS)
self.assertEqual(len(thp_v1.CHANNELS), 1)
def test_channel_default_state_is_TH1(self):
self.assertEqual(thp_v1.CHANNELS[4660].get_channel_state(), ChannelState.TH1)
if __name__ == "__main__":
unittest.main()

@ -1,12 +1,13 @@
from common import *
from ubinascii import hexlify, unhexlify
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor.wire.thp.writer import REPORT_LENGTH
from ubinascii import hexlify
import ustruct
from trezor import io, utils
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import thp_v1
from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID
from trezor.wire.protocol_common import MessageWithId
import trezor.wire.thp.thp_session as THP
from trezor.wire.thp import checksum
@ -39,9 +40,7 @@ CONT = 0x80
HEADER_INIT_LENGTH = 5
HEADER_CONT_LENGTH = 3
INIT_MESSAGE_DATA_LENGTH = (
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
)
INIT_MESSAGE_DATA_LENGTH = REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
def make_header(ctrl_byte, cid, length):
@ -81,7 +80,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
if not utils.USE_THP:
import storage.cache_thp # noqa: F401
def test_simple(self):
def _simple(self):
cid_req_header = make_header(
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
)
@ -116,7 +115,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer
self.assertEqual(buffer_without_zeroes, message_without_header)
def test_read_one_packet(self):
def _read_one_packet(self):
# zero length message - just a header
PLAINTEXT = getPlaintext()
header = make_header(
@ -142,7 +141,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def test_read_many_packets(self):
def _read_many_packets(self):
message = bytes(range(256))
header = make_header(
getPlaintext(),
@ -182,7 +181,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def test_read_large_message(self):
def _read_large_message(self):
message = b"hello world"
header = make_header(
getPlaintext(),
@ -218,7 +217,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def test_roundtrip(self):
def _roundtrip(self):
message_payload = bytes(range(256))
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
@ -244,7 +243,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def test_write_one_packet(self):
def _write_one_packet(self):
message = MessageWithId(
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
)
@ -266,7 +265,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
)
self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self):
def _write_multiple_packets(self):
message_payload = bytes(range(256))
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
@ -310,14 +309,11 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def test_read_huge_packet(self):
def _read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets
message_size = (PACKET_COUNT - 1) * (
thp_v1._REPORT_LENGTH
- HEADER_CONT_LENGTH
- CHECKSUM_LENGTH
- _MESSAGE_TYPE_LEN
REPORT_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH
# ensure that a message this big won't fit into memory

Loading…
Cancel
Save