diff --git a/core/src/all_modules.py b/core/src/all_modules.py index ba7971548..4617636f0 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -201,12 +201,14 @@ trezor.wire.context import trezor.wire.context trezor.wire.errors import trezor.wire.errors +trezor.wire.thp.thp_messages +import trezor.wire.thp.thp_messages trezor.wire.protocol import trezor.wire.protocol trezor.wire.protocol_common import trezor.wire.protocol_common -trezor.wire.thp_session -import trezor.wire.thp_session +trezor.wire.thp.thp_session +import trezor.wire.thp.thp_session trezor.wire.thp_v1 import trezor.wire.thp_v1 trezor.workflow diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 5d73c3dca..77a3dc6d0 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -89,8 +89,8 @@ if __debug__: async def _handle_single_message( - ctx: context.Context, msg: protocol_common.Message, use_workflow: bool -) -> protocol_common.Message | None: + ctx: context.Context, msg: protocol_common.MessageWithId, use_workflow: bool +) -> protocol_common.MessageWithId | None: """Handle a message that was loaded from USB by the caller. Find the appropriate handler, run it and write its result on the wire. In case @@ -206,7 +206,7 @@ async def handle_session( ctx_buffer = WIRE_BUFFER ctx = context.Context(iface, ctx_buffer, session_id) - next_msg: protocol_common.Message | None = None + next_msg: protocol_common.MessageWithId | None = None if __debug__ and is_debug_session: import apps.debug diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index c600201d5..4b0f60e36 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -3,7 +3,7 @@ from micropython import const from typing import TYPE_CHECKING from trezor import io, loop, utils -from trezor.wire.protocol_common import Message, WireError +from trezor.wire.protocol_common import MessageWithId, WireError if TYPE_CHECKING: from trezorio import WireInterface @@ -23,7 +23,7 @@ class CodecError(WireError): pass -async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: +async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: read = loop.wait(iface.iface_num() | io.POLL_READ) # wait for initial report @@ -65,7 +65,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag if read_and_throw_away: raise CodecError("Message too large") - return Message(mtype, mdata) + return MessageWithId(mtype, mdata) async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None: diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 5ba9232be..f29b557c0 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING import trezor.wire.protocol as protocol from trezor import log, loop, protobuf -from .protocol_common import Message +from .protocol_common import MessageWithId if TYPE_CHECKING: from trezorio import WireInterface @@ -49,7 +49,7 @@ class UnexpectedMessage(Exception): should be aborted and a new one started as if `msg` was the first message. """ - def __init__(self, msg: Message) -> None: + def __init__(self, msg: MessageWithId) -> None: super().__init__() self.msg = msg @@ -71,7 +71,7 @@ class Context: self.buffer = buffer self.session_id = session_id - def read_from_wire(self) -> Awaitable[Message]: + def read_from_wire(self) -> Awaitable[MessageWithId]: """Read a whole message from the wire without parsing it.""" return protocol.read_message(self.iface, self.buffer) @@ -177,7 +177,7 @@ class Context: msg_session_id = bytearray(self.session_id) await protocol.write_message( self.iface, - Message( + MessageWithId( message_type=msg.MESSAGE_WIRE_TYPE, message_data=memoryview(buffer)[:msg_size], session_id=msg_session_id, diff --git a/core/src/trezor/wire/protocol.py b/core/src/trezor/wire/protocol.py index 06732d3b5..de4bc7392 100644 --- a/core/src/trezor/wire/protocol.py +++ b/core/src/trezor/wire/protocol.py @@ -2,21 +2,21 @@ from typing import TYPE_CHECKING from trezor import utils from trezor.wire import codec_v1, thp_v1 -from trezor.wire.protocol_common import Message +from trezor.wire.protocol_common import MessageWithId if TYPE_CHECKING: from trezorio import WireInterface -async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: +async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: if utils.USE_THP: return await thp_v1.read_message(iface, buffer) return await codec_v1.read_message(iface, buffer) -async def write_message(iface: WireInterface, message: Message) -> None: +async def write_message(iface: WireInterface, message: MessageWithId) -> None: if utils.USE_THP: - await thp_v1.write_message(iface, message) + await thp_v1.write_message_with_sync_control(iface, message) return await codec_v1.write_message(iface, message.type, message.data) return diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index f0accc70b..7c7ab80f2 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -3,11 +3,23 @@ class Message: self, message_type: int, message_data: bytes, - session_id: bytearray | None = None, ) -> None: self.type = message_type self.data = message_data + + def to_bytes(self): + return self.type.to_bytes(2, "big") + self.data + + +class MessageWithId(Message): + def __init__( + self, + message_type: int, + message_data: bytes, + session_id: bytearray | None = None, + ) -> None: self.session_id = session_id + super().__init__(message_type, message_data) class WireError(Exception): diff --git a/core/src/trezor/wire/thp/ack_handler.py b/core/src/trezor/wire/thp/ack_handler.py new file mode 100644 index 000000000..578346553 --- /dev/null +++ b/core/src/trezor/wire/thp/ack_handler.py @@ -0,0 +1,22 @@ +from storage.cache_thp import SessionThpCache + +from . import thp_session as THP + + +def handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None: + + if _ack_is_not_expected(session): + return + if _ack_has_incorrect_sync_bit(session, sync_bit): + return + + # ACK is expected and it has correct sync bit + THP.sync_set_can_send_message(session, True) + + +def _ack_is_not_expected(session: SessionThpCache) -> bool: + return THP.sync_can_send_message(session) + + +def _ack_has_incorrect_sync_bit(session: SessionThpCache, sync_bit: int) -> bool: + return THP.sync_get_send_bit(session) != sync_bit diff --git a/core/src/trezor/wire/thp/checksum.py b/core/src/trezor/wire/thp/checksum.py new file mode 100644 index 000000000..0a6dae2d8 --- /dev/null +++ b/core/src/trezor/wire/thp/checksum.py @@ -0,0 +1,15 @@ +from micropython import const + +from trezor import utils +from trezor.crypto import crc + +CHECKSUM_LENGTH = const(4) + + +def compute(data: bytes | utils.BufferType) -> bytes: + return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big") + + +def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool: + data_checksum = compute(data) + return checksum == data_checksum diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py new file mode 100644 index 000000000..2837a0eda --- /dev/null +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -0,0 +1,71 @@ +import ustruct + +from storage.cache_thp import BROADCAST_CHANNEL_ID + +from ..protocol_common import Message + +CONTINUATION_PACKET = 0x80 +_ERROR = 0x41 +_CHANNEL_ALLOCATION_RES = 0x40 + + +class InitHeader: + format_str = ">BHH" + + def __init__(self, ctrl_byte, cid, length) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.length = length + + def to_bytes(self) -> bytes: + return ustruct.pack( + InitHeader.format_str, self.ctrl_byte, self.cid, self.length + ) + + def pack_to_buffer(self, buffer, buffer_offset=0) -> None: + ustruct.pack_into( + InitHeader.format_str, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.length, + ) + + def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None: + ustruct.pack_into(">BH", buffer, buffer_offset, CONTINUATION_PACKET, self.cid) + + @classmethod + def get_error_header(cls, cid, length): + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_response_header(cls, length): + return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length) + + +class InterruptingInitPacket: + def __init__(self, report: bytes) -> None: + self.initReport = report + + +_ENCODED_PROTOBUF_DEVICE_PROPERTIES = ( + b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02" +) + +_ERROR_UNALLOCATED_SESSION = ( + b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e" +) + + +def get_device_properties() -> Message: + return Message(1000, _ENCODED_PROTOBUF_DEVICE_PROPERTIES) + + +def get_channel_allocation_response(nonce: bytes, new_cid: int) -> bytes: + props_msg = get_device_properties() + return ustruct.pack(">8sH", nonce, new_cid) + props_msg.to_bytes() + + +def get_error_unallocated_channel() -> bytes: + return _ERROR_UNALLOCATED_SESSION diff --git a/core/src/trezor/wire/thp_session.py b/core/src/trezor/wire/thp/thp_session.py similarity index 100% rename from core/src/trezor/wire/thp_session.py rename to core/src/trezor/wire/thp/thp_session.py diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 8fbb0c850..f0ceee329 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -2,74 +2,35 @@ import ustruct from micropython import const from typing import TYPE_CHECKING -import trezor.wire.thp_session as THP from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache from trezor import io, loop, utils -from trezor.crypto import crc -from trezor.wire.protocol_common import Message -from trezor.wire.thp_session import SessionState, ThpError + +from .protocol_common import MessageWithId +from .thp import ack_handler, checksum, thp_messages +from .thp import thp_session as THP +from .thp.checksum import CHECKSUM_LENGTH +from .thp.thp_messages import CONTINUATION_PACKET, InitHeader, InterruptingInitPacket +from .thp.thp_session import SessionState, ThpError if TYPE_CHECKING: from trezorio import WireInterface _MAX_PAYLOAD_LEN = const(60000) _MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value -_CHECKSUM_LENGTH = const(4) _CHANNEL_ALLOCATION_REQ = 0x40 -_CHANNEL_ALLOCATION_RES = 0x40 -_ERROR = 0x41 -_CONTINUATION_PACKET = 0x80 _ACK_MESSAGE = 0x20 _HANDSHAKE_INIT = 0x00 _PLAINTEXT = 0x01 ENCRYPTED_TRANSPORT = 0x02 -_ENCODED_PROTOBUF_DEVICE_PROPERTIES = ( - b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02" -) -_UNALLOCATED_SESSION_ERROR = ( - b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e" -) _REPORT_LENGTH = const(64) _REPORT_INIT_DATA_OFFSET = const(5) _REPORT_CONT_DATA_OFFSET = const(3) -class InitHeader: - format_str = ">BHH" - - def __init__(self, ctrl_byte, cid, length) -> None: - self.ctrl_byte = ctrl_byte - self.cid = cid - self.length = length - - def to_bytes(self) -> bytes: - return ustruct.pack( - InitHeader.format_str, self.ctrl_byte, self.cid, self.length - ) - - def pack_to_buffer(self, buffer, buffer_offset=0) -> None: - ustruct.pack_into( - InitHeader.format_str, - buffer, - buffer_offset, - self.ctrl_byte, - self.cid, - self.length, - ) - - def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None: - ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid) - - -class InterruptingInitPacket: - def __init__(self, report: bytes) -> None: - self.initReport = report - - -async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: +async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: msg = await read_message_or_init_packet(iface, buffer) - while type(msg) is not Message: + while type(msg) is not MessageWithId: if isinstance(msg, InterruptingInitPacket): msg = await read_message_or_init_packet(iface, buffer, msg.initReport) else: @@ -79,7 +40,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag async def read_message_or_init_packet( iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None -) -> Message | InterruptingInitPacket: +) -> MessageWithId | InterruptingInitPacket: report = firstReport while True: # Wait for an initial report @@ -89,7 +50,7 @@ async def read_message_or_init_packet( raise ThpError("Reading failed unexpectedly, report is None.") # Channel multiplexing - ctrl_byte, cid = ustruct.unpack(">BH", report) + ctrl_byte, cid, payload_length = ustruct.unpack(">BHH", report) if cid == BROADCAST_CHANNEL_ID: await _handle_broadcast(iface, ctrl_byte, report) # TODO await @@ -104,7 +65,6 @@ async def read_message_or_init_packet( report = None continue - payload_length = ustruct.unpack(">H", report[3:])[0] payload = _get_buffer_for_payload(payload_length, buffer) header = InitHeader(ctrl_byte, cid, payload_length) @@ -114,7 +74,7 @@ async def read_message_or_init_packet( return interruptingPacket # Check CRC - if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]): + if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): # checksum is not valid -> ignore message report = None continue @@ -141,7 +101,7 @@ async def read_message_or_init_packet( # 1: Handle ACKs if _is_ctrl_byte_ack(ctrl_byte): - _handle_received_ACK(session, sync_bit) + ack_handler.handle_received_ACK(session, sync_bit) report = None continue @@ -215,8 +175,19 @@ async def _buffer_received_data( nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET) +async def write_message_with_sync_control( + iface: WireInterface, message: MessageWithId, is_retransmission: bool = False +) -> None: + session = THP.get_session_from_id(message.session_id) + if session is None: + raise ThpError("Invalid session") + if (not THP.sync_can_send_message(session)) and (not is_retransmission): + raise ThpError("Cannot send another message before ACK is received.") + await write_message(iface, message, is_retransmission) + + async def write_message( - iface: WireInterface, message: Message, is_retransmission: bool = False + iface: WireInterface, message: MessageWithId, is_retransmission: bool = False ) -> None: session = THP.get_session_from_id(message.session_id) if session is None: @@ -245,9 +216,9 @@ async def write_message( ctrl_byte, 1 - THP.sync_get_send_bit(session) ) - header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH) - checksum = _compute_checksum_bytes(header.to_bytes() + payload) - await write_to_wire(iface, header, payload + checksum) + header = InitHeader(ctrl_byte, cid, payload_length + CHECKSUM_LENGTH) + chksum = checksum.compute(header.to_bytes() + payload) + await write_to_wire(iface, header, payload + chksum) # TODO set timeout for retransmission @@ -283,7 +254,9 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None: return -async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message | None: +async def _handle_broadcast( + iface: WireInterface, ctrl_byte, report +) -> MessageWithId | None: if ctrl_byte != _CHANNEL_ALLOCATION_REQ: raise ThpError("Unexpected ctrl_byte in broadcast channel packet") @@ -291,67 +264,51 @@ async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length) payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH) - if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]): + if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): raise ThpError("Checksum is not valid") channel_id = _get_new_channel_id() THP.create_new_unauthenticated_session(iface, channel_id) - response_data = ( - ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES + response_data = thp_messages.get_channel_allocation_response(nonce, channel_id) + response_header = InitHeader.get_channel_allocation_response_header( + len(response_data) + CHECKSUM_LENGTH, ) - response_header = InitHeader( - _CHANNEL_ALLOCATION_RES, - BROADCAST_CHANNEL_ID, - len(response_data) + _CHECKSUM_LENGTH, - ) + chksum = checksum.compute(response_header.to_bytes() + response_data) + await write_to_wire(iface, response_header, response_data + chksum) - checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data) - await write_to_wire(iface, response_header, response_data + checksum) - -async def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message: +async def _handle_allocated( + ctrl_byte, session: SessionThpCache, payload +) -> MessageWithId: # Parameters session and ctrl_byte will be used to determine if the # communication should be encrypted or not message_type = ustruct.unpack(">H", payload)[0] # trim message type and checksum from payload - message_data = payload[2:-_CHECKSUM_LENGTH] - return Message(message_type, message_data, session.session_id) - - -def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None: - # No ACKs expected - if THP.sync_can_send_message(session): - return + message_data = payload[2:-CHECKSUM_LENGTH] + return MessageWithId(message_type, message_data, session.session_id) - # ACK has incorrect sync bit - if THP.sync_get_send_bit(session) != sync_bit: - return - # ACK is expected and it has correct sync bit - THP.sync_set_can_send_message(session, True) - - -async def _handle_unallocated(iface, cid) -> Message | None: - data = _UNALLOCATED_SESSION_ERROR - header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH) - checksum = _compute_checksum_bytes(header.to_bytes() + data) - await write_to_wire(iface, header, data + checksum) +async def _handle_unallocated(iface, cid) -> MessageWithId | None: + data = thp_messages.get_error_unallocated_channel() + header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) + chksum = checksum.compute(header.to_bytes() + data) + await write_to_wire(iface, header, data + chksum) async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None: ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit) - header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH) - checksum = _compute_checksum_bytes(header.to_bytes()) - await write_to_wire(iface, header, checksum) + header = InitHeader(ctrl_byte, cid, CHECKSUM_LENGTH) + chksum = checksum.compute(header.to_bytes()) + await write_to_wire(iface, header, chksum) async def _handle_unexpected_sync_bit( iface: WireInterface, cid: int, sync_bit: int -) -> Message | None: +) -> MessageWithId | None: await _sendAck(iface, cid, sync_bit) # TODO handle cancelation messages and messages on allocated channels without synchronization @@ -362,13 +319,8 @@ def _get_new_channel_id() -> int: return THP.get_next_channel_id() -def _is_checksum_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool: - data_checksum = _compute_checksum_bytes(data) - return checksum == data_checksum - - def _is_ctrl_byte_continuation(ctrl_byte) -> bool: - return ctrl_byte & 0x80 == _CONTINUATION_PACKET + return ctrl_byte & 0x80 == CONTINUATION_PACKET def _is_ctrl_byte_ack(ctrl_byte) -> bool: @@ -381,7 +333,3 @@ def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit): if sync_bit == 1: return ctrl_byte | 0x10 raise ThpError("Unexpected synchronization bit") - - -def _compute_checksum_bytes(data: bytes | utils.BufferType) -> bytes: - return crc.crc32(data).to_bytes(4, "big") diff --git a/core/tests/test_trezor.wire.thp_v1.py b/core/tests/test_trezor.wire.thp_v1.py index 1c8b5b605..4bea7dd8a 100644 --- a/core/tests/test_trezor.wire.thp_v1.py +++ b/core/tests/test_trezor.wire.thp_v1.py @@ -6,11 +6,11 @@ from trezor import io, utils from trezor.loop import wait from trezor.utils import chunks from trezor.wire import thp_v1 -from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID -from trezor.wire.protocol_common import Message -import trezor.wire.thp_session as THP - -from micropython import const +from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID +from trezor.wire.protocol_common import MessageWithId +import trezor.wire.thp.thp_session as THP +from trezor.wire.thp import checksum +from trezor.wire.thp.checksum import CHECKSUM_LENGTH class MockHID: @@ -123,10 +123,10 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): # zero length message - just a header PLAINTEXT = getPlaintext() header = make_header( - PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH + PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH ) - checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES) - message = header + MESSAGE_TYPE_BYTES + checksum + chksum = checksum.compute(header + MESSAGE_TYPE_BYTES) + message = header + MESSAGE_TYPE_BYTES + chksum buffer = bytearray(64) gen = thp_v1.read_message(self.interface, buffer) @@ -145,16 +145,16 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): self.assertEqual(result.data, b"") # message should have been read into the buffer - self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58) + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58) def test_read_many_packets(self): message = bytes(range(256)) header = make_header( getPlaintext(), COMMON_CID, - len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH, + len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH, ) - checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message) + chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message) # message = MESSAGE_TYPE_BYTES + message + checksum # first packet is init header + 59 bytes of data @@ -163,7 +163,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [ cont_header + chunk for chunk in chunks( - message[INIT_MESSAGE_DATA_LENGTH:] + checksum, + message[INIT_MESSAGE_DATA_LENGTH:] + chksum, 64 - HEADER_CONT_LENGTH, ) ] @@ -185,21 +185,21 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): self.assertEqual(result.data, message) # message should have been read into the buffer ) - self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum) + self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum) def test_read_large_message(self): message = b"hello world" header = make_header( getPlaintext(), COMMON_CID, - _MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH, + _MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH, ) packet = ( header + MESSAGE_TYPE_BYTES + message - + thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message) + + checksum.compute(header + MESSAGE_TYPE_BYTES + message) ) # make sure we fit into one packet, to make this easier @@ -225,7 +225,9 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): self.assertEqual(buffer, b"\x00") def test_write_one_packet(self): - message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)) + message = MessageWithId( + MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID) + ) gen = thp_v1.write_message(self.interface, message) query = gen.send(None) @@ -234,19 +236,19 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): gen.send(None) header = make_header( - PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH + PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH ) expected_message = ( header + MESSAGE_TYPE_BYTES - + thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES) - + b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH) + + checksum.compute(header + MESSAGE_TYPE_BYTES) + + b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH) ) self.assertTrue(self.interface.data == [expected_message]) def test_write_multiple_packets(self): message_payload = bytes(range(256)) - message = Message( + message = MessageWithId( MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) ) gen = thp_v1.write_message(self.interface, message) @@ -254,10 +256,10 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): header = make_header( PLAINTEXT_1, COMMON_CID, - len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH, + len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH, ) cont_header = make_cont_header() - checksum = thp_v1._compute_checksum_bytes( + chksum = checksum.compute( header + message.type.to_bytes(2, "big") + message.data ) packets = [ @@ -265,7 +267,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): ] + [ cont_header + chunk for chunk in chunks( - message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum, + message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum, thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH, ) ] @@ -290,7 +292,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): def test_roundtrip(self): message_payload = bytes(range(256)) - message = Message( + message = MessageWithId( MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) ) gen = thp_v1.write_message(self.interface, message) @@ -320,7 +322,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): message_size = (PACKET_COUNT - 1) * ( thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH - - _CHECKSUM_LENGTH + - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN ) + INIT_MESSAGE_DATA_LENGTH