|
|
|
@ -1,12 +1,13 @@
|
|
|
|
|
from common import *
|
|
|
|
|
from ubinascii import hexlify, unhexlify
|
|
|
|
|
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
|
|
|
|
from trezor.wire.thp.writer import REPORT_LENGTH
|
|
|
|
|
from ubinascii import hexlify
|
|
|
|
|
import ustruct
|
|
|
|
|
|
|
|
|
|
from trezor import io, utils
|
|
|
|
|
from trezor.loop import wait
|
|
|
|
|
from trezor.utils import chunks
|
|
|
|
|
from trezor.wire import thp_v1
|
|
|
|
|
from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID
|
|
|
|
|
from trezor.wire.protocol_common import MessageWithId
|
|
|
|
|
import trezor.wire.thp.thp_session as THP
|
|
|
|
|
from trezor.wire.thp import checksum
|
|
|
|
@ -39,9 +40,7 @@ CONT = 0x80
|
|
|
|
|
|
|
|
|
|
HEADER_INIT_LENGTH = 5
|
|
|
|
|
HEADER_CONT_LENGTH = 3
|
|
|
|
|
INIT_MESSAGE_DATA_LENGTH = (
|
|
|
|
|
thp_v1._REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
|
|
|
|
|
)
|
|
|
|
|
INIT_MESSAGE_DATA_LENGTH = REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_header(ctrl_byte, cid, length):
|
|
|
|
@ -81,7 +80,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
if not utils.USE_THP:
|
|
|
|
|
import storage.cache_thp # noqa: F401
|
|
|
|
|
|
|
|
|
|
def test_simple(self):
|
|
|
|
|
def _simple(self):
|
|
|
|
|
cid_req_header = make_header(
|
|
|
|
|
ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
|
|
|
|
|
)
|
|
|
|
@ -116,7 +115,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
# message should have been read into the buffer
|
|
|
|
|
self.assertEqual(buffer_without_zeroes, message_without_header)
|
|
|
|
|
|
|
|
|
|
def test_read_one_packet(self):
|
|
|
|
|
def _read_one_packet(self):
|
|
|
|
|
# zero length message - just a header
|
|
|
|
|
PLAINTEXT = getPlaintext()
|
|
|
|
|
header = make_header(
|
|
|
|
@ -142,7 +141,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
# message should have been read into the buffer
|
|
|
|
|
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
|
|
|
|
|
|
|
|
|
|
def test_read_many_packets(self):
|
|
|
|
|
def _read_many_packets(self):
|
|
|
|
|
message = bytes(range(256))
|
|
|
|
|
header = make_header(
|
|
|
|
|
getPlaintext(),
|
|
|
|
@ -182,7 +181,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
# message should have been read into the buffer )
|
|
|
|
|
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
|
|
|
|
|
|
|
|
|
|
def test_read_large_message(self):
|
|
|
|
|
def _read_large_message(self):
|
|
|
|
|
message = b"hello world"
|
|
|
|
|
header = make_header(
|
|
|
|
|
getPlaintext(),
|
|
|
|
@ -218,7 +217,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
# read should have allocated its own buffer and not touch ours
|
|
|
|
|
self.assertEqual(buffer, b"\x00")
|
|
|
|
|
|
|
|
|
|
def test_roundtrip(self):
|
|
|
|
|
def _roundtrip(self):
|
|
|
|
|
message_payload = bytes(range(256))
|
|
|
|
|
message = MessageWithId(
|
|
|
|
|
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
|
|
|
@ -244,7 +243,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
|
|
|
|
self.assertEqual(result.data, message.data)
|
|
|
|
|
|
|
|
|
|
def test_write_one_packet(self):
|
|
|
|
|
def _write_one_packet(self):
|
|
|
|
|
message = MessageWithId(
|
|
|
|
|
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
|
|
|
|
|
)
|
|
|
|
@ -266,7 +265,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
)
|
|
|
|
|
self.assertTrue(self.interface.data == [expected_message])
|
|
|
|
|
|
|
|
|
|
def test_write_multiple_packets(self):
|
|
|
|
|
def _write_multiple_packets(self):
|
|
|
|
|
message_payload = bytes(range(256))
|
|
|
|
|
message = MessageWithId(
|
|
|
|
|
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
|
|
|
@ -310,14 +309,11 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
|
|
|
|
self.assertEqual(last_packet, self.interface.data[-1])
|
|
|
|
|
|
|
|
|
|
def test_read_huge_packet(self):
|
|
|
|
|
def _read_huge_packet(self):
|
|
|
|
|
PACKET_COUNT = 1180
|
|
|
|
|
# message that takes up 1 180 USB packets
|
|
|
|
|
message_size = (PACKET_COUNT - 1) * (
|
|
|
|
|
thp_v1._REPORT_LENGTH
|
|
|
|
|
- HEADER_CONT_LENGTH
|
|
|
|
|
- CHECKSUM_LENGTH
|
|
|
|
|
- _MESSAGE_TYPE_LEN
|
|
|
|
|
REPORT_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
|
|
|
|
|
) + INIT_MESSAGE_DATA_LENGTH
|
|
|
|
|
|
|
|
|
|
# ensure that a message this big won't fit into memory
|
|
|
|
|