Channel allocation handling + refactor

M1nd3r/thp5
M1nd3r 2 months ago committed by M1nd3r
parent 372998eb88
commit 7ade02e2b7

@ -201,12 +201,14 @@ trezor.wire.context
import trezor.wire.context
trezor.wire.errors
import trezor.wire.errors
trezor.wire.thp.thp_messages
import trezor.wire.thp.thp_messages
trezor.wire.protocol
import trezor.wire.protocol
trezor.wire.protocol_common
import trezor.wire.protocol_common
trezor.wire.thp_session
import trezor.wire.thp_session
trezor.wire.thp.thp_session
import trezor.wire.thp.thp_session
trezor.wire.thp_v1
import trezor.wire.thp_v1
trezor.workflow

@ -89,8 +89,8 @@ if __debug__:
async def _handle_single_message(
ctx: context.Context, msg: protocol_common.Message, use_workflow: bool
) -> protocol_common.Message | None:
ctx: context.Context, msg: protocol_common.MessageWithId, use_workflow: bool
) -> protocol_common.MessageWithId | None:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case
@ -206,7 +206,7 @@ async def handle_session(
ctx_buffer = WIRE_BUFFER
ctx = context.Context(iface, ctx_buffer, session_id)
next_msg: protocol_common.Message | None = None
next_msg: protocol_common.MessageWithId | None = None
if __debug__ and is_debug_session:
import apps.debug

@ -3,7 +3,7 @@ from micropython import const
from typing import TYPE_CHECKING
from trezor import io, loop, utils
from trezor.wire.protocol_common import Message, WireError
from trezor.wire.protocol_common import MessageWithId, WireError
if TYPE_CHECKING:
from trezorio import WireInterface
@ -23,7 +23,7 @@ class CodecError(WireError):
pass
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
read = loop.wait(iface.iface_num() | io.POLL_READ)
# wait for initial report
@ -65,7 +65,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
if read_and_throw_away:
raise CodecError("Message too large")
return Message(mtype, mdata)
return MessageWithId(mtype, mdata)
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
import trezor.wire.protocol as protocol
from trezor import log, loop, protobuf
from .protocol_common import Message
from .protocol_common import MessageWithId
if TYPE_CHECKING:
from trezorio import WireInterface
@ -49,7 +49,7 @@ class UnexpectedMessage(Exception):
should be aborted and a new one started as if `msg` was the first message.
"""
def __init__(self, msg: Message) -> None:
def __init__(self, msg: MessageWithId) -> None:
super().__init__()
self.msg = msg
@ -71,7 +71,7 @@ class Context:
self.buffer = buffer
self.session_id = session_id
def read_from_wire(self) -> Awaitable[Message]:
def read_from_wire(self) -> Awaitable[MessageWithId]:
"""Read a whole message from the wire without parsing it."""
return protocol.read_message(self.iface, self.buffer)
@ -177,7 +177,7 @@ class Context:
msg_session_id = bytearray(self.session_id)
await protocol.write_message(
self.iface,
Message(
MessageWithId(
message_type=msg.MESSAGE_WIRE_TYPE,
message_data=memoryview(buffer)[:msg_size],
session_id=msg_session_id,

@ -2,21 +2,21 @@ from typing import TYPE_CHECKING
from trezor import utils
from trezor.wire import codec_v1, thp_v1
from trezor.wire.protocol_common import Message
from trezor.wire.protocol_common import MessageWithId
if TYPE_CHECKING:
from trezorio import WireInterface
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
if utils.USE_THP:
return await thp_v1.read_message(iface, buffer)
return await codec_v1.read_message(iface, buffer)
async def write_message(iface: WireInterface, message: Message) -> None:
async def write_message(iface: WireInterface, message: MessageWithId) -> None:
if utils.USE_THP:
await thp_v1.write_message(iface, message)
await thp_v1.write_message_with_sync_control(iface, message)
return
await codec_v1.write_message(iface, message.type, message.data)
return

@ -3,11 +3,23 @@ class Message:
self,
message_type: int,
message_data: bytes,
session_id: bytearray | None = None,
) -> None:
self.type = message_type
self.data = message_data
def to_bytes(self):
return self.type.to_bytes(2, "big") + self.data
class MessageWithId(Message):
def __init__(
self,
message_type: int,
message_data: bytes,
session_id: bytearray | None = None,
) -> None:
self.session_id = session_id
super().__init__(message_type, message_data)
class WireError(Exception):

@ -0,0 +1,22 @@
from storage.cache_thp import SessionThpCache
from . import thp_session as THP
def handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
if _ack_is_not_expected(session):
return
if _ack_has_incorrect_sync_bit(session, sync_bit):
return
# ACK is expected and it has correct sync bit
THP.sync_set_can_send_message(session, True)
def _ack_is_not_expected(session: SessionThpCache) -> bool:
return THP.sync_can_send_message(session)
def _ack_has_incorrect_sync_bit(session: SessionThpCache, sync_bit: int) -> bool:
return THP.sync_get_send_bit(session) != sync_bit

@ -0,0 +1,15 @@
from micropython import const
from trezor import utils
from trezor.crypto import crc
CHECKSUM_LENGTH = const(4)
def compute(data: bytes | utils.BufferType) -> bytes:
return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
data_checksum = compute(data)
return checksum == data_checksum

@ -0,0 +1,71 @@
import ustruct
from storage.cache_thp import BROADCAST_CHANNEL_ID
from ..protocol_common import Message
CONTINUATION_PACKET = 0x80
_ERROR = 0x41
_CHANNEL_ALLOCATION_RES = 0x40
class InitHeader:
format_str = ">BHH"
def __init__(self, ctrl_byte, cid, length) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
)
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(
InitHeader.format_str,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(">BH", buffer, buffer_offset, CONTINUATION_PACKET, self.cid)
@classmethod
def get_error_header(cls, cid, length):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_response_header(cls, length):
return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
class InterruptingInitPacket:
def __init__(self, report: bytes) -> None:
self.initReport = report
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
)
_ERROR_UNALLOCATED_SESSION = (
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
)
def get_device_properties() -> Message:
return Message(1000, _ENCODED_PROTOBUF_DEVICE_PROPERTIES)
def get_channel_allocation_response(nonce: bytes, new_cid: int) -> bytes:
props_msg = get_device_properties()
return ustruct.pack(">8sH", nonce, new_cid) + props_msg.to_bytes()
def get_error_unallocated_channel() -> bytes:
return _ERROR_UNALLOCATED_SESSION

@ -2,74 +2,35 @@ import ustruct
from micropython import const
from typing import TYPE_CHECKING
import trezor.wire.thp_session as THP
from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache
from trezor import io, loop, utils
from trezor.crypto import crc
from trezor.wire.protocol_common import Message
from trezor.wire.thp_session import SessionState, ThpError
from .protocol_common import MessageWithId
from .thp import ack_handler, checksum, thp_messages
from .thp import thp_session as THP
from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import CONTINUATION_PACKET, InitHeader, InterruptingInitPacket
from .thp.thp_session import SessionState, ThpError
if TYPE_CHECKING:
from trezorio import WireInterface
_MAX_PAYLOAD_LEN = const(60000)
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
_CHECKSUM_LENGTH = const(4)
_CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x40
_ERROR = 0x41
_CONTINUATION_PACKET = 0x80
_ACK_MESSAGE = 0x20
_HANDSHAKE_INIT = 0x00
_PLAINTEXT = 0x01
ENCRYPTED_TRANSPORT = 0x02
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
)
_UNALLOCATED_SESSION_ERROR = (
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
)
_REPORT_LENGTH = const(64)
_REPORT_INIT_DATA_OFFSET = const(5)
_REPORT_CONT_DATA_OFFSET = const(3)
class InitHeader:
format_str = ">BHH"
def __init__(self, ctrl_byte, cid, length) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
)
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(
InitHeader.format_str,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid)
class InterruptingInitPacket:
def __init__(self, report: bytes) -> None:
self.initReport = report
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
msg = await read_message_or_init_packet(iface, buffer)
while type(msg) is not Message:
while type(msg) is not MessageWithId:
if isinstance(msg, InterruptingInitPacket):
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
else:
@ -79,7 +40,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
async def read_message_or_init_packet(
iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None
) -> Message | InterruptingInitPacket:
) -> MessageWithId | InterruptingInitPacket:
report = firstReport
while True:
# Wait for an initial report
@ -89,7 +50,7 @@ async def read_message_or_init_packet(
raise ThpError("Reading failed unexpectedly, report is None.")
# Channel multiplexing
ctrl_byte, cid = ustruct.unpack(">BH", report)
ctrl_byte, cid, payload_length = ustruct.unpack(">BHH", report)
if cid == BROADCAST_CHANNEL_ID:
await _handle_broadcast(iface, ctrl_byte, report) # TODO await
@ -104,7 +65,6 @@ async def read_message_or_init_packet(
report = None
continue
payload_length = ustruct.unpack(">H", report[3:])[0]
payload = _get_buffer_for_payload(payload_length, buffer)
header = InitHeader(ctrl_byte, cid, payload_length)
@ -114,7 +74,7 @@ async def read_message_or_init_packet(
return interruptingPacket
# Check CRC
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
# checksum is not valid -> ignore message
report = None
continue
@ -141,7 +101,7 @@ async def read_message_or_init_packet(
# 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte):
_handle_received_ACK(session, sync_bit)
ack_handler.handle_received_ACK(session, sync_bit)
report = None
continue
@ -215,8 +175,19 @@ async def _buffer_received_data(
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
async def write_message_with_sync_control(
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
) -> None:
session = THP.get_session_from_id(message.session_id)
if session is None:
raise ThpError("Invalid session")
if (not THP.sync_can_send_message(session)) and (not is_retransmission):
raise ThpError("Cannot send another message before ACK is received.")
await write_message(iface, message, is_retransmission)
async def write_message(
iface: WireInterface, message: Message, is_retransmission: bool = False
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
) -> None:
session = THP.get_session_from_id(message.session_id)
if session is None:
@ -245,9 +216,9 @@ async def write_message(
ctrl_byte, 1 - THP.sync_get_send_bit(session)
)
header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes() + payload)
await write_to_wire(iface, header, payload + checksum)
header = InitHeader(ctrl_byte, cid, payload_length + CHECKSUM_LENGTH)
chksum = checksum.compute(header.to_bytes() + payload)
await write_to_wire(iface, header, payload + chksum)
# TODO set timeout for retransmission
@ -283,7 +254,9 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
return
async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message | None:
async def _handle_broadcast(
iface: WireInterface, ctrl_byte, report
) -> MessageWithId | None:
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
@ -291,67 +264,51 @@ async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid")
channel_id = _get_new_channel_id()
THP.create_new_unauthenticated_session(iface, channel_id)
response_data = (
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES
response_data = thp_messages.get_channel_allocation_response(nonce, channel_id)
response_header = InitHeader.get_channel_allocation_response_header(
len(response_data) + CHECKSUM_LENGTH,
)
response_header = InitHeader(
_CHANNEL_ALLOCATION_RES,
BROADCAST_CHANNEL_ID,
len(response_data) + _CHECKSUM_LENGTH,
)
chksum = checksum.compute(response_header.to_bytes() + response_data)
await write_to_wire(iface, response_header, response_data + chksum)
checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data)
await write_to_wire(iface, response_header, response_data + checksum)
async def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message:
async def _handle_allocated(
ctrl_byte, session: SessionThpCache, payload
) -> MessageWithId:
# Parameters session and ctrl_byte will be used to determine if the
# communication should be encrypted or not
message_type = ustruct.unpack(">H", payload)[0]
# trim message type and checksum from payload
message_data = payload[2:-_CHECKSUM_LENGTH]
return Message(message_type, message_data, session.session_id)
def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
# No ACKs expected
if THP.sync_can_send_message(session):
return
message_data = payload[2:-CHECKSUM_LENGTH]
return MessageWithId(message_type, message_data, session.session_id)
# ACK has incorrect sync bit
if THP.sync_get_send_bit(session) != sync_bit:
return
# ACK is expected and it has correct sync bit
THP.sync_set_can_send_message(session, True)
async def _handle_unallocated(iface, cid) -> Message | None:
data = _UNALLOCATED_SESSION_ERROR
header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes() + data)
await write_to_wire(iface, header, data + checksum)
async def _handle_unallocated(iface, cid) -> MessageWithId | None:
data = thp_messages.get_error_unallocated_channel()
header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
chksum = checksum.compute(header.to_bytes() + data)
await write_to_wire(iface, header, data + chksum)
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes())
await write_to_wire(iface, header, checksum)
header = InitHeader(ctrl_byte, cid, CHECKSUM_LENGTH)
chksum = checksum.compute(header.to_bytes())
await write_to_wire(iface, header, chksum)
async def _handle_unexpected_sync_bit(
iface: WireInterface, cid: int, sync_bit: int
) -> Message | None:
) -> MessageWithId | None:
await _sendAck(iface, cid, sync_bit)
# TODO handle cancelation messages and messages on allocated channels without synchronization
@ -362,13 +319,8 @@ def _get_new_channel_id() -> int:
return THP.get_next_channel_id()
def _is_checksum_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
data_checksum = _compute_checksum_bytes(data)
return checksum == data_checksum
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
return ctrl_byte & 0x80 == _CONTINUATION_PACKET
return ctrl_byte & 0x80 == CONTINUATION_PACKET
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
@ -381,7 +333,3 @@ def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
if sync_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit")
def _compute_checksum_bytes(data: bytes | utils.BufferType) -> bytes:
return crc.crc32(data).to_bytes(4, "big")

@ -6,11 +6,11 @@ from trezor import io, utils
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import thp_v1
from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID
from trezor.wire.protocol_common import Message
import trezor.wire.thp_session as THP
from micropython import const
from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID
from trezor.wire.protocol_common import MessageWithId
import trezor.wire.thp.thp_session as THP
from trezor.wire.thp import checksum
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
class MockHID:
@ -123,10 +123,10 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# zero length message - just a header
PLAINTEXT = getPlaintext()
header = make_header(
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
message = header + MESSAGE_TYPE_BYTES + checksum
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
message = header + MESSAGE_TYPE_BYTES + chksum
buffer = bytearray(64)
gen = thp_v1.read_message(self.interface, buffer)
@ -145,16 +145,16 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(result.data, b"")
# message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58)
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def test_read_many_packets(self):
message = bytes(range(256))
header = make_header(
getPlaintext(),
COMMON_CID,
len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
# message = MESSAGE_TYPE_BYTES + message + checksum
# first packet is init header + 59 bytes of data
@ -163,7 +163,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
cont_header + chunk
for chunk in chunks(
message[INIT_MESSAGE_DATA_LENGTH:] + checksum,
message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
64 - HEADER_CONT_LENGTH,
)
]
@ -185,21 +185,21 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(result.data, message)
# message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum)
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def test_read_large_message(self):
message = b"hello world"
header = make_header(
getPlaintext(),
COMMON_CID,
_MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH,
_MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
)
packet = (
header
+ MESSAGE_TYPE_BYTES
+ message
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
+ checksum.compute(header + MESSAGE_TYPE_BYTES + message)
)
# make sure we fit into one packet, to make this easier
@ -225,7 +225,9 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(buffer, b"\x00")
def test_write_one_packet(self):
message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID))
message = MessageWithId(
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
query = gen.send(None)
@ -234,19 +236,19 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
gen.send(None)
header = make_header(
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
expected_message = (
header
+ MESSAGE_TYPE_BYTES
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH)
+ checksum.compute(header + MESSAGE_TYPE_BYTES)
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
)
self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self):
message_payload = bytes(range(256))
message = Message(
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
@ -254,10 +256,10 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
header = make_header(
PLAINTEXT_1,
COMMON_CID,
len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
)
cont_header = make_cont_header()
checksum = thp_v1._compute_checksum_bytes(
chksum = checksum.compute(
header + message.type.to_bytes(2, "big") + message.data
)
packets = [
@ -265,7 +267,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
] + [
cont_header + chunk
for chunk in chunks(
message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum,
message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
)
]
@ -290,7 +292,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
def test_roundtrip(self):
message_payload = bytes(range(256))
message = Message(
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
@ -320,7 +322,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
message_size = (PACKET_COUNT - 1) * (
thp_v1._REPORT_LENGTH
- HEADER_CONT_LENGTH
- _CHECKSUM_LENGTH
- CHECKSUM_LENGTH
- _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH

Loading…
Cancel
Save