1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-24 13:22:05 +00:00

Channel allocation handling + refactor

This commit is contained in:
M1nd3r 2024-03-14 14:07:47 +01:00 committed by M1nd3r
parent 84c069cdbf
commit aa346086c3
12 changed files with 219 additions and 147 deletions

View File

@ -201,12 +201,14 @@ trezor.wire.context
import trezor.wire.context import trezor.wire.context
trezor.wire.errors trezor.wire.errors
import trezor.wire.errors import trezor.wire.errors
trezor.wire.thp.thp_messages
import trezor.wire.thp.thp_messages
trezor.wire.protocol trezor.wire.protocol
import trezor.wire.protocol import trezor.wire.protocol
trezor.wire.protocol_common trezor.wire.protocol_common
import trezor.wire.protocol_common import trezor.wire.protocol_common
trezor.wire.thp_session trezor.wire.thp.thp_session
import trezor.wire.thp_session import trezor.wire.thp.thp_session
trezor.wire.thp_v1 trezor.wire.thp_v1
import trezor.wire.thp_v1 import trezor.wire.thp_v1
trezor.workflow trezor.workflow

View File

@ -89,8 +89,8 @@ if __debug__:
async def _handle_single_message( async def _handle_single_message(
ctx: context.Context, msg: protocol_common.Message, use_workflow: bool ctx: context.Context, msg: protocol_common.MessageWithId, use_workflow: bool
) -> protocol_common.Message | None: ) -> protocol_common.MessageWithId | None:
"""Handle a message that was loaded from USB by the caller. """Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case Find the appropriate handler, run it and write its result on the wire. In case
@ -206,7 +206,7 @@ async def handle_session(
ctx_buffer = WIRE_BUFFER ctx_buffer = WIRE_BUFFER
ctx = context.Context(iface, ctx_buffer, session_id) ctx = context.Context(iface, ctx_buffer, session_id)
next_msg: protocol_common.Message | None = None next_msg: protocol_common.MessageWithId | None = None
if __debug__ and is_debug_session: if __debug__ and is_debug_session:
import apps.debug import apps.debug

View File

@ -3,7 +3,7 @@ from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import io, loop, utils from trezor import io, loop, utils
from trezor.wire.protocol_common import Message, WireError from trezor.wire.protocol_common import MessageWithId, WireError
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
@ -23,7 +23,7 @@ class CodecError(WireError):
pass pass
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
read = loop.wait(iface.iface_num() | io.POLL_READ) read = loop.wait(iface.iface_num() | io.POLL_READ)
# wait for initial report # wait for initial report
@ -65,7 +65,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
if read_and_throw_away: if read_and_throw_away:
raise CodecError("Message too large") raise CodecError("Message too large")
return Message(mtype, mdata) return MessageWithId(mtype, mdata)
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None: async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:

View File

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
import trezor.wire.protocol as protocol import trezor.wire.protocol as protocol
from trezor import log, loop, protobuf from trezor import log, loop, protobuf
from .protocol_common import Message from .protocol_common import MessageWithId
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
@ -49,7 +49,7 @@ class UnexpectedMessage(Exception):
should be aborted and a new one started as if `msg` was the first message. should be aborted and a new one started as if `msg` was the first message.
""" """
def __init__(self, msg: Message) -> None: def __init__(self, msg: MessageWithId) -> None:
super().__init__() super().__init__()
self.msg = msg self.msg = msg
@ -71,7 +71,7 @@ class Context:
self.buffer = buffer self.buffer = buffer
self.session_id = session_id self.session_id = session_id
def read_from_wire(self) -> Awaitable[Message]: def read_from_wire(self) -> Awaitable[MessageWithId]:
"""Read a whole message from the wire without parsing it.""" """Read a whole message from the wire without parsing it."""
return protocol.read_message(self.iface, self.buffer) return protocol.read_message(self.iface, self.buffer)
@ -177,7 +177,7 @@ class Context:
msg_session_id = bytearray(self.session_id) msg_session_id = bytearray(self.session_id)
await protocol.write_message( await protocol.write_message(
self.iface, self.iface,
Message( MessageWithId(
message_type=msg.MESSAGE_WIRE_TYPE, message_type=msg.MESSAGE_WIRE_TYPE,
message_data=memoryview(buffer)[:msg_size], message_data=memoryview(buffer)[:msg_size],
session_id=msg_session_id, session_id=msg_session_id,

View File

@ -2,21 +2,21 @@ from typing import TYPE_CHECKING
from trezor import utils from trezor import utils
from trezor.wire import codec_v1, thp_v1 from trezor.wire import codec_v1, thp_v1
from trezor.wire.protocol_common import Message from trezor.wire.protocol_common import MessageWithId
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
if utils.USE_THP: if utils.USE_THP:
return await thp_v1.read_message(iface, buffer) return await thp_v1.read_message(iface, buffer)
return await codec_v1.read_message(iface, buffer) return await codec_v1.read_message(iface, buffer)
async def write_message(iface: WireInterface, message: Message) -> None: async def write_message(iface: WireInterface, message: MessageWithId) -> None:
if utils.USE_THP: if utils.USE_THP:
await thp_v1.write_message(iface, message) await thp_v1.write_message_with_sync_control(iface, message)
return return
await codec_v1.write_message(iface, message.type, message.data) await codec_v1.write_message(iface, message.type, message.data)
return return

View File

@ -1,13 +1,25 @@
class Message: class Message:
def __init__(
self,
message_type: int,
message_data: bytes,
) -> None:
self.type = message_type
self.data = message_data
def to_bytes(self):
return self.type.to_bytes(2, "big") + self.data
class MessageWithId(Message):
def __init__( def __init__(
self, self,
message_type: int, message_type: int,
message_data: bytes, message_data: bytes,
session_id: bytearray | None = None, session_id: bytearray | None = None,
) -> None: ) -> None:
self.type = message_type
self.data = message_data
self.session_id = session_id self.session_id = session_id
super().__init__(message_type, message_data)
class WireError(Exception): class WireError(Exception):

View File

@ -0,0 +1,22 @@
from storage.cache_thp import SessionThpCache
from . import thp_session as THP
def handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
if _ack_is_not_expected(session):
return
if _ack_has_incorrect_sync_bit(session, sync_bit):
return
# ACK is expected and it has correct sync bit
THP.sync_set_can_send_message(session, True)
def _ack_is_not_expected(session: SessionThpCache) -> bool:
return THP.sync_can_send_message(session)
def _ack_has_incorrect_sync_bit(session: SessionThpCache, sync_bit: int) -> bool:
return THP.sync_get_send_bit(session) != sync_bit

View File

@ -0,0 +1,15 @@
from micropython import const
from trezor import utils
from trezor.crypto import crc
CHECKSUM_LENGTH = const(4)
def compute(data: bytes | utils.BufferType) -> bytes:
return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
data_checksum = compute(data)
return checksum == data_checksum

View File

@ -0,0 +1,71 @@
import ustruct
from storage.cache_thp import BROADCAST_CHANNEL_ID
from ..protocol_common import Message
CONTINUATION_PACKET = 0x80
_ERROR = 0x41
_CHANNEL_ALLOCATION_RES = 0x40
class InitHeader:
format_str = ">BHH"
def __init__(self, ctrl_byte, cid, length) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
)
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(
InitHeader.format_str,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(">BH", buffer, buffer_offset, CONTINUATION_PACKET, self.cid)
@classmethod
def get_error_header(cls, cid, length):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_response_header(cls, length):
return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
class InterruptingInitPacket:
def __init__(self, report: bytes) -> None:
self.initReport = report
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
)
_ERROR_UNALLOCATED_SESSION = (
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
)
def get_device_properties() -> Message:
return Message(1000, _ENCODED_PROTOBUF_DEVICE_PROPERTIES)
def get_channel_allocation_response(nonce: bytes, new_cid: int) -> bytes:
props_msg = get_device_properties()
return ustruct.pack(">8sH", nonce, new_cid) + props_msg.to_bytes()
def get_error_unallocated_channel() -> bytes:
return _ERROR_UNALLOCATED_SESSION

View File

@ -2,74 +2,35 @@ import ustruct
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import trezor.wire.thp_session as THP
from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache
from trezor import io, loop, utils from trezor import io, loop, utils
from trezor.crypto import crc
from trezor.wire.protocol_common import Message from .protocol_common import MessageWithId
from trezor.wire.thp_session import SessionState, ThpError from .thp import ack_handler, checksum, thp_messages
from .thp import thp_session as THP
from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import CONTINUATION_PACKET, InitHeader, InterruptingInitPacket
from .thp.thp_session import SessionState, ThpError
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
_MAX_PAYLOAD_LEN = const(60000) _MAX_PAYLOAD_LEN = const(60000)
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value _MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
_CHECKSUM_LENGTH = const(4)
_CHANNEL_ALLOCATION_REQ = 0x40 _CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x40
_ERROR = 0x41
_CONTINUATION_PACKET = 0x80
_ACK_MESSAGE = 0x20 _ACK_MESSAGE = 0x20
_HANDSHAKE_INIT = 0x00 _HANDSHAKE_INIT = 0x00
_PLAINTEXT = 0x01 _PLAINTEXT = 0x01
ENCRYPTED_TRANSPORT = 0x02 ENCRYPTED_TRANSPORT = 0x02
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
)
_UNALLOCATED_SESSION_ERROR = (
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
)
_REPORT_LENGTH = const(64) _REPORT_LENGTH = const(64)
_REPORT_INIT_DATA_OFFSET = const(5) _REPORT_INIT_DATA_OFFSET = const(5)
_REPORT_CONT_DATA_OFFSET = const(3) _REPORT_CONT_DATA_OFFSET = const(3)
class InitHeader: async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
format_str = ">BHH"
def __init__(self, ctrl_byte, cid, length) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
def to_bytes(self) -> bytes:
return ustruct.pack(
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
)
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(
InitHeader.format_str,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.length,
)
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid)
class InterruptingInitPacket:
def __init__(self, report: bytes) -> None:
self.initReport = report
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
msg = await read_message_or_init_packet(iface, buffer) msg = await read_message_or_init_packet(iface, buffer)
while type(msg) is not Message: while type(msg) is not MessageWithId:
if isinstance(msg, InterruptingInitPacket): if isinstance(msg, InterruptingInitPacket):
msg = await read_message_or_init_packet(iface, buffer, msg.initReport) msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
else: else:
@ -79,7 +40,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
async def read_message_or_init_packet( async def read_message_or_init_packet(
iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None
) -> Message | InterruptingInitPacket: ) -> MessageWithId | InterruptingInitPacket:
report = firstReport report = firstReport
while True: while True:
# Wait for an initial report # Wait for an initial report
@ -89,7 +50,7 @@ async def read_message_or_init_packet(
raise ThpError("Reading failed unexpectedly, report is None.") raise ThpError("Reading failed unexpectedly, report is None.")
# Channel multiplexing # Channel multiplexing
ctrl_byte, cid = ustruct.unpack(">BH", report) ctrl_byte, cid, payload_length = ustruct.unpack(">BHH", report)
if cid == BROADCAST_CHANNEL_ID: if cid == BROADCAST_CHANNEL_ID:
await _handle_broadcast(iface, ctrl_byte, report) # TODO await await _handle_broadcast(iface, ctrl_byte, report) # TODO await
@ -104,7 +65,6 @@ async def read_message_or_init_packet(
report = None report = None
continue continue
payload_length = ustruct.unpack(">H", report[3:])[0]
payload = _get_buffer_for_payload(payload_length, buffer) payload = _get_buffer_for_payload(payload_length, buffer)
header = InitHeader(ctrl_byte, cid, payload_length) header = InitHeader(ctrl_byte, cid, payload_length)
@ -114,7 +74,7 @@ async def read_message_or_init_packet(
return interruptingPacket return interruptingPacket
# Check CRC # Check CRC
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]): if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
# checksum is not valid -> ignore message # checksum is not valid -> ignore message
report = None report = None
continue continue
@ -141,7 +101,7 @@ async def read_message_or_init_packet(
# 1: Handle ACKs # 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte): if _is_ctrl_byte_ack(ctrl_byte):
_handle_received_ACK(session, sync_bit) ack_handler.handle_received_ACK(session, sync_bit)
report = None report = None
continue continue
@ -215,8 +175,19 @@ async def _buffer_received_data(
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET) nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
async def write_message_with_sync_control(
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
) -> None:
session = THP.get_session_from_id(message.session_id)
if session is None:
raise ThpError("Invalid session")
if (not THP.sync_can_send_message(session)) and (not is_retransmission):
raise ThpError("Cannot send another message before ACK is received.")
await write_message(iface, message, is_retransmission)
async def write_message( async def write_message(
iface: WireInterface, message: Message, is_retransmission: bool = False iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
) -> None: ) -> None:
session = THP.get_session_from_id(message.session_id) session = THP.get_session_from_id(message.session_id)
if session is None: if session is None:
@ -245,9 +216,9 @@ async def write_message(
ctrl_byte, 1 - THP.sync_get_send_bit(session) ctrl_byte, 1 - THP.sync_get_send_bit(session)
) )
header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH) header = InitHeader(ctrl_byte, cid, payload_length + CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes() + payload) chksum = checksum.compute(header.to_bytes() + payload)
await write_to_wire(iface, header, payload + checksum) await write_to_wire(iface, header, payload + chksum)
# TODO set timeout for retransmission # TODO set timeout for retransmission
@ -283,7 +254,9 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
return return
async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message | None: async def _handle_broadcast(
iface: WireInterface, ctrl_byte, report
) -> MessageWithId | None:
if ctrl_byte != _CHANNEL_ALLOCATION_REQ: if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in broadcast channel packet") raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
@ -291,67 +264,51 @@ async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length) header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH) payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]): if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid") raise ThpError("Checksum is not valid")
channel_id = _get_new_channel_id() channel_id = _get_new_channel_id()
THP.create_new_unauthenticated_session(iface, channel_id) THP.create_new_unauthenticated_session(iface, channel_id)
response_data = ( response_data = thp_messages.get_channel_allocation_response(nonce, channel_id)
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES response_header = InitHeader.get_channel_allocation_response_header(
len(response_data) + CHECKSUM_LENGTH,
) )
response_header = InitHeader( chksum = checksum.compute(response_header.to_bytes() + response_data)
_CHANNEL_ALLOCATION_RES, await write_to_wire(iface, response_header, response_data + chksum)
BROADCAST_CHANNEL_ID,
len(response_data) + _CHECKSUM_LENGTH,
)
checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data)
await write_to_wire(iface, response_header, response_data + checksum)
async def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message: async def _handle_allocated(
ctrl_byte, session: SessionThpCache, payload
) -> MessageWithId:
# Parameters session and ctrl_byte will be used to determine if the # Parameters session and ctrl_byte will be used to determine if the
# communication should be encrypted or not # communication should be encrypted or not
message_type = ustruct.unpack(">H", payload)[0] message_type = ustruct.unpack(">H", payload)[0]
# trim message type and checksum from payload # trim message type and checksum from payload
message_data = payload[2:-_CHECKSUM_LENGTH] message_data = payload[2:-CHECKSUM_LENGTH]
return Message(message_type, message_data, session.session_id) return MessageWithId(message_type, message_data, session.session_id)
def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None: async def _handle_unallocated(iface, cid) -> MessageWithId | None:
# No ACKs expected data = thp_messages.get_error_unallocated_channel()
if THP.sync_can_send_message(session): header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
return chksum = checksum.compute(header.to_bytes() + data)
await write_to_wire(iface, header, data + chksum)
# ACK has incorrect sync bit
if THP.sync_get_send_bit(session) != sync_bit:
return
# ACK is expected and it has correct sync bit
THP.sync_set_can_send_message(session, True)
async def _handle_unallocated(iface, cid) -> Message | None:
data = _UNALLOCATED_SESSION_ERROR
header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes() + data)
await write_to_wire(iface, header, data + checksum)
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None: async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit) ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH) header = InitHeader(ctrl_byte, cid, CHECKSUM_LENGTH)
checksum = _compute_checksum_bytes(header.to_bytes()) chksum = checksum.compute(header.to_bytes())
await write_to_wire(iface, header, checksum) await write_to_wire(iface, header, chksum)
async def _handle_unexpected_sync_bit( async def _handle_unexpected_sync_bit(
iface: WireInterface, cid: int, sync_bit: int iface: WireInterface, cid: int, sync_bit: int
) -> Message | None: ) -> MessageWithId | None:
await _sendAck(iface, cid, sync_bit) await _sendAck(iface, cid, sync_bit)
# TODO handle cancelation messages and messages on allocated channels without synchronization # TODO handle cancelation messages and messages on allocated channels without synchronization
@ -362,13 +319,8 @@ def _get_new_channel_id() -> int:
return THP.get_next_channel_id() return THP.get_next_channel_id()
def _is_checksum_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
data_checksum = _compute_checksum_bytes(data)
return checksum == data_checksum
def _is_ctrl_byte_continuation(ctrl_byte) -> bool: def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
return ctrl_byte & 0x80 == _CONTINUATION_PACKET return ctrl_byte & 0x80 == CONTINUATION_PACKET
def _is_ctrl_byte_ack(ctrl_byte) -> bool: def _is_ctrl_byte_ack(ctrl_byte) -> bool:
@ -381,7 +333,3 @@ def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
if sync_bit == 1: if sync_bit == 1:
return ctrl_byte | 0x10 return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit") raise ThpError("Unexpected synchronization bit")
def _compute_checksum_bytes(data: bytes | utils.BufferType) -> bytes:
return crc.crc32(data).to_bytes(4, "big")

View File

@ -6,11 +6,11 @@ from trezor import io, utils
from trezor.loop import wait from trezor.loop import wait
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire import thp_v1 from trezor.wire import thp_v1
from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID
from trezor.wire.protocol_common import Message from trezor.wire.protocol_common import MessageWithId
import trezor.wire.thp_session as THP import trezor.wire.thp.thp_session as THP
from trezor.wire.thp import checksum
from micropython import const from trezor.wire.thp.checksum import CHECKSUM_LENGTH
class MockHID: class MockHID:
@ -123,10 +123,10 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# zero length message - just a header # zero length message - just a header
PLAINTEXT = getPlaintext() PLAINTEXT = getPlaintext()
header = make_header( header = make_header(
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
) )
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES) chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
message = header + MESSAGE_TYPE_BYTES + checksum message = header + MESSAGE_TYPE_BYTES + chksum
buffer = bytearray(64) buffer = bytearray(64)
gen = thp_v1.read_message(self.interface, buffer) gen = thp_v1.read_message(self.interface, buffer)
@ -145,16 +145,16 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(result.data, b"") self.assertEqual(result.data, b"")
# message should have been read into the buffer # message should have been read into the buffer
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58) self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
def test_read_many_packets(self): def test_read_many_packets(self):
message = bytes(range(256)) message = bytes(range(256))
header = make_header( header = make_header(
getPlaintext(), getPlaintext(),
COMMON_CID, COMMON_CID,
len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH, len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
) )
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message) chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
# message = MESSAGE_TYPE_BYTES + message + checksum # message = MESSAGE_TYPE_BYTES + message + checksum
# first packet is init header + 59 bytes of data # first packet is init header + 59 bytes of data
@ -163,7 +163,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [ packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
cont_header + chunk cont_header + chunk
for chunk in chunks( for chunk in chunks(
message[INIT_MESSAGE_DATA_LENGTH:] + checksum, message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
64 - HEADER_CONT_LENGTH, 64 - HEADER_CONT_LENGTH,
) )
] ]
@ -185,21 +185,21 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(result.data, message) self.assertEqual(result.data, message)
# message should have been read into the buffer ) # message should have been read into the buffer )
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum) self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
def test_read_large_message(self): def test_read_large_message(self):
message = b"hello world" message = b"hello world"
header = make_header( header = make_header(
getPlaintext(), getPlaintext(),
COMMON_CID, COMMON_CID,
_MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH, _MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
) )
packet = ( packet = (
header header
+ MESSAGE_TYPE_BYTES + MESSAGE_TYPE_BYTES
+ message + message
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message) + checksum.compute(header + MESSAGE_TYPE_BYTES + message)
) )
# make sure we fit into one packet, to make this easier # make sure we fit into one packet, to make this easier
@ -225,7 +225,9 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
self.assertEqual(buffer, b"\x00") self.assertEqual(buffer, b"\x00")
def test_write_one_packet(self): def test_write_one_packet(self):
message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)) message = MessageWithId(
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message) gen = thp_v1.write_message(self.interface, message)
query = gen.send(None) query = gen.send(None)
@ -234,19 +236,19 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
gen.send(None) gen.send(None)
header = make_header( header = make_header(
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
) )
expected_message = ( expected_message = (
header header
+ MESSAGE_TYPE_BYTES + MESSAGE_TYPE_BYTES
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES) + checksum.compute(header + MESSAGE_TYPE_BYTES)
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH) + b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
) )
self.assertTrue(self.interface.data == [expected_message]) self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self): def test_write_multiple_packets(self):
message_payload = bytes(range(256)) message_payload = bytes(range(256))
message = Message( message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
) )
gen = thp_v1.write_message(self.interface, message) gen = thp_v1.write_message(self.interface, message)
@ -254,10 +256,10 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
header = make_header( header = make_header(
PLAINTEXT_1, PLAINTEXT_1,
COMMON_CID, COMMON_CID,
len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH, len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
) )
cont_header = make_cont_header() cont_header = make_cont_header()
checksum = thp_v1._compute_checksum_bytes( chksum = checksum.compute(
header + message.type.to_bytes(2, "big") + message.data header + message.type.to_bytes(2, "big") + message.data
) )
packets = [ packets = [
@ -265,7 +267,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
] + [ ] + [
cont_header + chunk cont_header + chunk
for chunk in chunks( for chunk in chunks(
message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum, message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH, thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
) )
] ]
@ -290,7 +292,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
def test_roundtrip(self): def test_roundtrip(self):
message_payload = bytes(range(256)) message_payload = bytes(range(256))
message = Message( message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
) )
gen = thp_v1.write_message(self.interface, message) gen = thp_v1.write_message(self.interface, message)
@ -320,7 +322,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
message_size = (PACKET_COUNT - 1) * ( message_size = (PACKET_COUNT - 1) * (
thp_v1._REPORT_LENGTH thp_v1._REPORT_LENGTH
- HEADER_CONT_LENGTH - HEADER_CONT_LENGTH
- _CHECKSUM_LENGTH - CHECKSUM_LENGTH
- _MESSAGE_TYPE_LEN - _MESSAGE_TYPE_LEN
) + INIT_MESSAGE_DATA_LENGTH ) + INIT_MESSAGE_DATA_LENGTH