mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-24 13:22:05 +00:00
Channel allocation handling + refactor
This commit is contained in:
parent
84c069cdbf
commit
aa346086c3
6
core/src/all_modules.py
generated
6
core/src/all_modules.py
generated
@ -201,12 +201,14 @@ trezor.wire.context
|
|||||||
import trezor.wire.context
|
import trezor.wire.context
|
||||||
trezor.wire.errors
|
trezor.wire.errors
|
||||||
import trezor.wire.errors
|
import trezor.wire.errors
|
||||||
|
trezor.wire.thp.thp_messages
|
||||||
|
import trezor.wire.thp.thp_messages
|
||||||
trezor.wire.protocol
|
trezor.wire.protocol
|
||||||
import trezor.wire.protocol
|
import trezor.wire.protocol
|
||||||
trezor.wire.protocol_common
|
trezor.wire.protocol_common
|
||||||
import trezor.wire.protocol_common
|
import trezor.wire.protocol_common
|
||||||
trezor.wire.thp_session
|
trezor.wire.thp.thp_session
|
||||||
import trezor.wire.thp_session
|
import trezor.wire.thp.thp_session
|
||||||
trezor.wire.thp_v1
|
trezor.wire.thp_v1
|
||||||
import trezor.wire.thp_v1
|
import trezor.wire.thp_v1
|
||||||
trezor.workflow
|
trezor.workflow
|
||||||
|
@ -89,8 +89,8 @@ if __debug__:
|
|||||||
|
|
||||||
|
|
||||||
async def _handle_single_message(
|
async def _handle_single_message(
|
||||||
ctx: context.Context, msg: protocol_common.Message, use_workflow: bool
|
ctx: context.Context, msg: protocol_common.MessageWithId, use_workflow: bool
|
||||||
) -> protocol_common.Message | None:
|
) -> protocol_common.MessageWithId | None:
|
||||||
"""Handle a message that was loaded from USB by the caller.
|
"""Handle a message that was loaded from USB by the caller.
|
||||||
|
|
||||||
Find the appropriate handler, run it and write its result on the wire. In case
|
Find the appropriate handler, run it and write its result on the wire. In case
|
||||||
@ -206,7 +206,7 @@ async def handle_session(
|
|||||||
ctx_buffer = WIRE_BUFFER
|
ctx_buffer = WIRE_BUFFER
|
||||||
|
|
||||||
ctx = context.Context(iface, ctx_buffer, session_id)
|
ctx = context.Context(iface, ctx_buffer, session_id)
|
||||||
next_msg: protocol_common.Message | None = None
|
next_msg: protocol_common.MessageWithId | None = None
|
||||||
|
|
||||||
if __debug__ and is_debug_session:
|
if __debug__ and is_debug_session:
|
||||||
import apps.debug
|
import apps.debug
|
||||||
|
@ -3,7 +3,7 @@ from micropython import const
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import io, loop, utils
|
from trezor import io, loop, utils
|
||||||
from trezor.wire.protocol_common import Message, WireError
|
from trezor.wire.protocol_common import MessageWithId, WireError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
@ -23,7 +23,7 @@ class CodecError(WireError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
|
||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
||||||
# wait for initial report
|
# wait for initial report
|
||||||
@ -65,7 +65,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
if read_and_throw_away:
|
if read_and_throw_away:
|
||||||
raise CodecError("Message too large")
|
raise CodecError("Message too large")
|
||||||
|
|
||||||
return Message(mtype, mdata)
|
return MessageWithId(mtype, mdata)
|
||||||
|
|
||||||
|
|
||||||
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
async def write_message(iface: WireInterface, mtype: int, mdata: bytes) -> None:
|
||||||
|
@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
|
|||||||
import trezor.wire.protocol as protocol
|
import trezor.wire.protocol as protocol
|
||||||
from trezor import log, loop, protobuf
|
from trezor import log, loop, protobuf
|
||||||
|
|
||||||
from .protocol_common import Message
|
from .protocol_common import MessageWithId
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
@ -49,7 +49,7 @@ class UnexpectedMessage(Exception):
|
|||||||
should be aborted and a new one started as if `msg` was the first message.
|
should be aborted and a new one started as if `msg` was the first message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, msg: Message) -> None:
|
def __init__(self, msg: MessageWithId) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
|
||||||
@ -71,7 +71,7 @@ class Context:
|
|||||||
self.buffer = buffer
|
self.buffer = buffer
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
|
|
||||||
def read_from_wire(self) -> Awaitable[Message]:
|
def read_from_wire(self) -> Awaitable[MessageWithId]:
|
||||||
"""Read a whole message from the wire without parsing it."""
|
"""Read a whole message from the wire without parsing it."""
|
||||||
return protocol.read_message(self.iface, self.buffer)
|
return protocol.read_message(self.iface, self.buffer)
|
||||||
|
|
||||||
@ -177,7 +177,7 @@ class Context:
|
|||||||
msg_session_id = bytearray(self.session_id)
|
msg_session_id = bytearray(self.session_id)
|
||||||
await protocol.write_message(
|
await protocol.write_message(
|
||||||
self.iface,
|
self.iface,
|
||||||
Message(
|
MessageWithId(
|
||||||
message_type=msg.MESSAGE_WIRE_TYPE,
|
message_type=msg.MESSAGE_WIRE_TYPE,
|
||||||
message_data=memoryview(buffer)[:msg_size],
|
message_data=memoryview(buffer)[:msg_size],
|
||||||
session_id=msg_session_id,
|
session_id=msg_session_id,
|
||||||
|
@ -2,21 +2,21 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from trezor import utils
|
from trezor import utils
|
||||||
from trezor.wire import codec_v1, thp_v1
|
from trezor.wire import codec_v1, thp_v1
|
||||||
from trezor.wire.protocol_common import Message
|
from trezor.wire.protocol_common import MessageWithId
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
|
|
||||||
|
|
||||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
|
||||||
if utils.USE_THP:
|
if utils.USE_THP:
|
||||||
return await thp_v1.read_message(iface, buffer)
|
return await thp_v1.read_message(iface, buffer)
|
||||||
return await codec_v1.read_message(iface, buffer)
|
return await codec_v1.read_message(iface, buffer)
|
||||||
|
|
||||||
|
|
||||||
async def write_message(iface: WireInterface, message: Message) -> None:
|
async def write_message(iface: WireInterface, message: MessageWithId) -> None:
|
||||||
if utils.USE_THP:
|
if utils.USE_THP:
|
||||||
await thp_v1.write_message(iface, message)
|
await thp_v1.write_message_with_sync_control(iface, message)
|
||||||
return
|
return
|
||||||
await codec_v1.write_message(iface, message.type, message.data)
|
await codec_v1.write_message(iface, message.type, message.data)
|
||||||
return
|
return
|
||||||
|
@ -1,13 +1,25 @@
|
|||||||
class Message:
|
class Message:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_type: int,
|
||||||
|
message_data: bytes,
|
||||||
|
) -> None:
|
||||||
|
self.type = message_type
|
||||||
|
self.data = message_data
|
||||||
|
|
||||||
|
def to_bytes(self):
|
||||||
|
return self.type.to_bytes(2, "big") + self.data
|
||||||
|
|
||||||
|
|
||||||
|
class MessageWithId(Message):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_type: int,
|
message_type: int,
|
||||||
message_data: bytes,
|
message_data: bytes,
|
||||||
session_id: bytearray | None = None,
|
session_id: bytearray | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.type = message_type
|
|
||||||
self.data = message_data
|
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
|
super().__init__(message_type, message_data)
|
||||||
|
|
||||||
|
|
||||||
class WireError(Exception):
|
class WireError(Exception):
|
||||||
|
22
core/src/trezor/wire/thp/ack_handler.py
Normal file
22
core/src/trezor/wire/thp/ack_handler.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from storage.cache_thp import SessionThpCache
|
||||||
|
|
||||||
|
from . import thp_session as THP
|
||||||
|
|
||||||
|
|
||||||
|
def handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
|
||||||
|
|
||||||
|
if _ack_is_not_expected(session):
|
||||||
|
return
|
||||||
|
if _ack_has_incorrect_sync_bit(session, sync_bit):
|
||||||
|
return
|
||||||
|
|
||||||
|
# ACK is expected and it has correct sync bit
|
||||||
|
THP.sync_set_can_send_message(session, True)
|
||||||
|
|
||||||
|
|
||||||
|
def _ack_is_not_expected(session: SessionThpCache) -> bool:
|
||||||
|
return THP.sync_can_send_message(session)
|
||||||
|
|
||||||
|
|
||||||
|
def _ack_has_incorrect_sync_bit(session: SessionThpCache, sync_bit: int) -> bool:
|
||||||
|
return THP.sync_get_send_bit(session) != sync_bit
|
15
core/src/trezor/wire/thp/checksum.py
Normal file
15
core/src/trezor/wire/thp/checksum.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from micropython import const
|
||||||
|
|
||||||
|
from trezor import utils
|
||||||
|
from trezor.crypto import crc
|
||||||
|
|
||||||
|
CHECKSUM_LENGTH = const(4)
|
||||||
|
|
||||||
|
|
||||||
|
def compute(data: bytes | utils.BufferType) -> bytes:
|
||||||
|
return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
|
||||||
|
data_checksum = compute(data)
|
||||||
|
return checksum == data_checksum
|
71
core/src/trezor/wire/thp/thp_messages.py
Normal file
71
core/src/trezor/wire/thp/thp_messages.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import ustruct
|
||||||
|
|
||||||
|
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||||
|
|
||||||
|
from ..protocol_common import Message
|
||||||
|
|
||||||
|
CONTINUATION_PACKET = 0x80
|
||||||
|
_ERROR = 0x41
|
||||||
|
_CHANNEL_ALLOCATION_RES = 0x40
|
||||||
|
|
||||||
|
|
||||||
|
class InitHeader:
|
||||||
|
format_str = ">BHH"
|
||||||
|
|
||||||
|
def __init__(self, ctrl_byte, cid, length) -> None:
|
||||||
|
self.ctrl_byte = ctrl_byte
|
||||||
|
self.cid = cid
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def to_bytes(self) -> bytes:
|
||||||
|
return ustruct.pack(
|
||||||
|
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
|
||||||
|
)
|
||||||
|
|
||||||
|
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
|
||||||
|
ustruct.pack_into(
|
||||||
|
InitHeader.format_str,
|
||||||
|
buffer,
|
||||||
|
buffer_offset,
|
||||||
|
self.ctrl_byte,
|
||||||
|
self.cid,
|
||||||
|
self.length,
|
||||||
|
)
|
||||||
|
|
||||||
|
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
|
||||||
|
ustruct.pack_into(">BH", buffer, buffer_offset, CONTINUATION_PACKET, self.cid)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_error_header(cls, cid, length):
|
||||||
|
return cls(_ERROR, cid, length)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_channel_allocation_response_header(cls, length):
|
||||||
|
return cls(_CHANNEL_ALLOCATION_RES, BROADCAST_CHANNEL_ID, length)
|
||||||
|
|
||||||
|
|
||||||
|
class InterruptingInitPacket:
|
||||||
|
def __init__(self, report: bytes) -> None:
|
||||||
|
self.initReport = report
|
||||||
|
|
||||||
|
|
||||||
|
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
|
||||||
|
b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
|
||||||
|
)
|
||||||
|
|
||||||
|
_ERROR_UNALLOCATED_SESSION = (
|
||||||
|
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_properties() -> Message:
|
||||||
|
return Message(1000, _ENCODED_PROTOBUF_DEVICE_PROPERTIES)
|
||||||
|
|
||||||
|
|
||||||
|
def get_channel_allocation_response(nonce: bytes, new_cid: int) -> bytes:
|
||||||
|
props_msg = get_device_properties()
|
||||||
|
return ustruct.pack(">8sH", nonce, new_cid) + props_msg.to_bytes()
|
||||||
|
|
||||||
|
|
||||||
|
def get_error_unallocated_channel() -> bytes:
|
||||||
|
return _ERROR_UNALLOCATED_SESSION
|
@ -2,74 +2,35 @@ import ustruct
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import trezor.wire.thp_session as THP
|
|
||||||
from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache
|
from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache
|
||||||
from trezor import io, loop, utils
|
from trezor import io, loop, utils
|
||||||
from trezor.crypto import crc
|
|
||||||
from trezor.wire.protocol_common import Message
|
from .protocol_common import MessageWithId
|
||||||
from trezor.wire.thp_session import SessionState, ThpError
|
from .thp import ack_handler, checksum, thp_messages
|
||||||
|
from .thp import thp_session as THP
|
||||||
|
from .thp.checksum import CHECKSUM_LENGTH
|
||||||
|
from .thp.thp_messages import CONTINUATION_PACKET, InitHeader, InterruptingInitPacket
|
||||||
|
from .thp.thp_session import SessionState, ThpError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
|
|
||||||
_MAX_PAYLOAD_LEN = const(60000)
|
_MAX_PAYLOAD_LEN = const(60000)
|
||||||
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
|
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
|
||||||
_CHECKSUM_LENGTH = const(4)
|
|
||||||
_CHANNEL_ALLOCATION_REQ = 0x40
|
_CHANNEL_ALLOCATION_REQ = 0x40
|
||||||
_CHANNEL_ALLOCATION_RES = 0x40
|
|
||||||
_ERROR = 0x41
|
|
||||||
_CONTINUATION_PACKET = 0x80
|
|
||||||
_ACK_MESSAGE = 0x20
|
_ACK_MESSAGE = 0x20
|
||||||
_HANDSHAKE_INIT = 0x00
|
_HANDSHAKE_INIT = 0x00
|
||||||
_PLAINTEXT = 0x01
|
_PLAINTEXT = 0x01
|
||||||
ENCRYPTED_TRANSPORT = 0x02
|
ENCRYPTED_TRANSPORT = 0x02
|
||||||
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
|
|
||||||
b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
|
|
||||||
)
|
|
||||||
_UNALLOCATED_SESSION_ERROR = (
|
|
||||||
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
|
|
||||||
)
|
|
||||||
|
|
||||||
_REPORT_LENGTH = const(64)
|
_REPORT_LENGTH = const(64)
|
||||||
_REPORT_INIT_DATA_OFFSET = const(5)
|
_REPORT_INIT_DATA_OFFSET = const(5)
|
||||||
_REPORT_CONT_DATA_OFFSET = const(3)
|
_REPORT_CONT_DATA_OFFSET = const(3)
|
||||||
|
|
||||||
|
|
||||||
class InitHeader:
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
|
||||||
format_str = ">BHH"
|
|
||||||
|
|
||||||
def __init__(self, ctrl_byte, cid, length) -> None:
|
|
||||||
self.ctrl_byte = ctrl_byte
|
|
||||||
self.cid = cid
|
|
||||||
self.length = length
|
|
||||||
|
|
||||||
def to_bytes(self) -> bytes:
|
|
||||||
return ustruct.pack(
|
|
||||||
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
|
|
||||||
)
|
|
||||||
|
|
||||||
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
|
|
||||||
ustruct.pack_into(
|
|
||||||
InitHeader.format_str,
|
|
||||||
buffer,
|
|
||||||
buffer_offset,
|
|
||||||
self.ctrl_byte,
|
|
||||||
self.cid,
|
|
||||||
self.length,
|
|
||||||
)
|
|
||||||
|
|
||||||
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
|
|
||||||
ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid)
|
|
||||||
|
|
||||||
|
|
||||||
class InterruptingInitPacket:
|
|
||||||
def __init__(self, report: bytes) -> None:
|
|
||||||
self.initReport = report
|
|
||||||
|
|
||||||
|
|
||||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
|
||||||
msg = await read_message_or_init_packet(iface, buffer)
|
msg = await read_message_or_init_packet(iface, buffer)
|
||||||
while type(msg) is not Message:
|
while type(msg) is not MessageWithId:
|
||||||
if isinstance(msg, InterruptingInitPacket):
|
if isinstance(msg, InterruptingInitPacket):
|
||||||
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
|
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
|
||||||
else:
|
else:
|
||||||
@ -79,7 +40,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
|
|
||||||
async def read_message_or_init_packet(
|
async def read_message_or_init_packet(
|
||||||
iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None
|
iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None
|
||||||
) -> Message | InterruptingInitPacket:
|
) -> MessageWithId | InterruptingInitPacket:
|
||||||
report = firstReport
|
report = firstReport
|
||||||
while True:
|
while True:
|
||||||
# Wait for an initial report
|
# Wait for an initial report
|
||||||
@ -89,7 +50,7 @@ async def read_message_or_init_packet(
|
|||||||
raise ThpError("Reading failed unexpectedly, report is None.")
|
raise ThpError("Reading failed unexpectedly, report is None.")
|
||||||
|
|
||||||
# Channel multiplexing
|
# Channel multiplexing
|
||||||
ctrl_byte, cid = ustruct.unpack(">BH", report)
|
ctrl_byte, cid, payload_length = ustruct.unpack(">BHH", report)
|
||||||
|
|
||||||
if cid == BROADCAST_CHANNEL_ID:
|
if cid == BROADCAST_CHANNEL_ID:
|
||||||
await _handle_broadcast(iface, ctrl_byte, report) # TODO await
|
await _handle_broadcast(iface, ctrl_byte, report) # TODO await
|
||||||
@ -104,7 +65,6 @@ async def read_message_or_init_packet(
|
|||||||
report = None
|
report = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload_length = ustruct.unpack(">H", report[3:])[0]
|
|
||||||
payload = _get_buffer_for_payload(payload_length, buffer)
|
payload = _get_buffer_for_payload(payload_length, buffer)
|
||||||
header = InitHeader(ctrl_byte, cid, payload_length)
|
header = InitHeader(ctrl_byte, cid, payload_length)
|
||||||
|
|
||||||
@ -114,7 +74,7 @@ async def read_message_or_init_packet(
|
|||||||
return interruptingPacket
|
return interruptingPacket
|
||||||
|
|
||||||
# Check CRC
|
# Check CRC
|
||||||
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||||
# checksum is not valid -> ignore message
|
# checksum is not valid -> ignore message
|
||||||
report = None
|
report = None
|
||||||
continue
|
continue
|
||||||
@ -141,7 +101,7 @@ async def read_message_or_init_packet(
|
|||||||
|
|
||||||
# 1: Handle ACKs
|
# 1: Handle ACKs
|
||||||
if _is_ctrl_byte_ack(ctrl_byte):
|
if _is_ctrl_byte_ack(ctrl_byte):
|
||||||
_handle_received_ACK(session, sync_bit)
|
ack_handler.handle_received_ACK(session, sync_bit)
|
||||||
report = None
|
report = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -215,8 +175,19 @@ async def _buffer_received_data(
|
|||||||
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
|
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
|
||||||
|
|
||||||
|
|
||||||
|
async def write_message_with_sync_control(
|
||||||
|
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
|
||||||
|
) -> None:
|
||||||
|
session = THP.get_session_from_id(message.session_id)
|
||||||
|
if session is None:
|
||||||
|
raise ThpError("Invalid session")
|
||||||
|
if (not THP.sync_can_send_message(session)) and (not is_retransmission):
|
||||||
|
raise ThpError("Cannot send another message before ACK is received.")
|
||||||
|
await write_message(iface, message, is_retransmission)
|
||||||
|
|
||||||
|
|
||||||
async def write_message(
|
async def write_message(
|
||||||
iface: WireInterface, message: Message, is_retransmission: bool = False
|
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
session = THP.get_session_from_id(message.session_id)
|
session = THP.get_session_from_id(message.session_id)
|
||||||
if session is None:
|
if session is None:
|
||||||
@ -245,9 +216,9 @@ async def write_message(
|
|||||||
ctrl_byte, 1 - THP.sync_get_send_bit(session)
|
ctrl_byte, 1 - THP.sync_get_send_bit(session)
|
||||||
)
|
)
|
||||||
|
|
||||||
header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH)
|
header = InitHeader(ctrl_byte, cid, payload_length + CHECKSUM_LENGTH)
|
||||||
checksum = _compute_checksum_bytes(header.to_bytes() + payload)
|
chksum = checksum.compute(header.to_bytes() + payload)
|
||||||
await write_to_wire(iface, header, payload + checksum)
|
await write_to_wire(iface, header, payload + chksum)
|
||||||
# TODO set timeout for retransmission
|
# TODO set timeout for retransmission
|
||||||
|
|
||||||
|
|
||||||
@ -283,7 +254,9 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message | None:
|
async def _handle_broadcast(
|
||||||
|
iface: WireInterface, ctrl_byte, report
|
||||||
|
) -> MessageWithId | None:
|
||||||
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
|
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
|
||||||
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
||||||
|
|
||||||
@ -291,67 +264,51 @@ async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message
|
|||||||
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
|
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
|
||||||
|
|
||||||
payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
|
payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
|
||||||
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||||
raise ThpError("Checksum is not valid")
|
raise ThpError("Checksum is not valid")
|
||||||
|
|
||||||
channel_id = _get_new_channel_id()
|
channel_id = _get_new_channel_id()
|
||||||
THP.create_new_unauthenticated_session(iface, channel_id)
|
THP.create_new_unauthenticated_session(iface, channel_id)
|
||||||
|
|
||||||
response_data = (
|
response_data = thp_messages.get_channel_allocation_response(nonce, channel_id)
|
||||||
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES
|
response_header = InitHeader.get_channel_allocation_response_header(
|
||||||
|
len(response_data) + CHECKSUM_LENGTH,
|
||||||
)
|
)
|
||||||
|
|
||||||
response_header = InitHeader(
|
chksum = checksum.compute(response_header.to_bytes() + response_data)
|
||||||
_CHANNEL_ALLOCATION_RES,
|
await write_to_wire(iface, response_header, response_data + chksum)
|
||||||
BROADCAST_CHANNEL_ID,
|
|
||||||
len(response_data) + _CHECKSUM_LENGTH,
|
|
||||||
)
|
|
||||||
|
|
||||||
checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data)
|
|
||||||
await write_to_wire(iface, response_header, response_data + checksum)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message:
|
async def _handle_allocated(
|
||||||
|
ctrl_byte, session: SessionThpCache, payload
|
||||||
|
) -> MessageWithId:
|
||||||
# Parameters session and ctrl_byte will be used to determine if the
|
# Parameters session and ctrl_byte will be used to determine if the
|
||||||
# communication should be encrypted or not
|
# communication should be encrypted or not
|
||||||
|
|
||||||
message_type = ustruct.unpack(">H", payload)[0]
|
message_type = ustruct.unpack(">H", payload)[0]
|
||||||
|
|
||||||
# trim message type and checksum from payload
|
# trim message type and checksum from payload
|
||||||
message_data = payload[2:-_CHECKSUM_LENGTH]
|
message_data = payload[2:-CHECKSUM_LENGTH]
|
||||||
return Message(message_type, message_data, session.session_id)
|
return MessageWithId(message_type, message_data, session.session_id)
|
||||||
|
|
||||||
|
|
||||||
def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
|
async def _handle_unallocated(iface, cid) -> MessageWithId | None:
|
||||||
# No ACKs expected
|
data = thp_messages.get_error_unallocated_channel()
|
||||||
if THP.sync_can_send_message(session):
|
header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
|
||||||
return
|
chksum = checksum.compute(header.to_bytes() + data)
|
||||||
|
await write_to_wire(iface, header, data + chksum)
|
||||||
# ACK has incorrect sync bit
|
|
||||||
if THP.sync_get_send_bit(session) != sync_bit:
|
|
||||||
return
|
|
||||||
|
|
||||||
# ACK is expected and it has correct sync bit
|
|
||||||
THP.sync_set_can_send_message(session, True)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_unallocated(iface, cid) -> Message | None:
|
|
||||||
data = _UNALLOCATED_SESSION_ERROR
|
|
||||||
header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH)
|
|
||||||
checksum = _compute_checksum_bytes(header.to_bytes() + data)
|
|
||||||
await write_to_wire(iface, header, data + checksum)
|
|
||||||
|
|
||||||
|
|
||||||
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
|
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
|
||||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
||||||
header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH)
|
header = InitHeader(ctrl_byte, cid, CHECKSUM_LENGTH)
|
||||||
checksum = _compute_checksum_bytes(header.to_bytes())
|
chksum = checksum.compute(header.to_bytes())
|
||||||
await write_to_wire(iface, header, checksum)
|
await write_to_wire(iface, header, chksum)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_unexpected_sync_bit(
|
async def _handle_unexpected_sync_bit(
|
||||||
iface: WireInterface, cid: int, sync_bit: int
|
iface: WireInterface, cid: int, sync_bit: int
|
||||||
) -> Message | None:
|
) -> MessageWithId | None:
|
||||||
await _sendAck(iface, cid, sync_bit)
|
await _sendAck(iface, cid, sync_bit)
|
||||||
|
|
||||||
# TODO handle cancelation messages and messages on allocated channels without synchronization
|
# TODO handle cancelation messages and messages on allocated channels without synchronization
|
||||||
@ -362,13 +319,8 @@ def _get_new_channel_id() -> int:
|
|||||||
return THP.get_next_channel_id()
|
return THP.get_next_channel_id()
|
||||||
|
|
||||||
|
|
||||||
def _is_checksum_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
|
|
||||||
data_checksum = _compute_checksum_bytes(data)
|
|
||||||
return checksum == data_checksum
|
|
||||||
|
|
||||||
|
|
||||||
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
|
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
|
||||||
return ctrl_byte & 0x80 == _CONTINUATION_PACKET
|
return ctrl_byte & 0x80 == CONTINUATION_PACKET
|
||||||
|
|
||||||
|
|
||||||
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
|
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
|
||||||
@ -381,7 +333,3 @@ def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
|
|||||||
if sync_bit == 1:
|
if sync_bit == 1:
|
||||||
return ctrl_byte | 0x10
|
return ctrl_byte | 0x10
|
||||||
raise ThpError("Unexpected synchronization bit")
|
raise ThpError("Unexpected synchronization bit")
|
||||||
|
|
||||||
|
|
||||||
def _compute_checksum_bytes(data: bytes | utils.BufferType) -> bytes:
|
|
||||||
return crc.crc32(data).to_bytes(4, "big")
|
|
||||||
|
@ -6,11 +6,11 @@ from trezor import io, utils
|
|||||||
from trezor.loop import wait
|
from trezor.loop import wait
|
||||||
from trezor.utils import chunks
|
from trezor.utils import chunks
|
||||||
from trezor.wire import thp_v1
|
from trezor.wire import thp_v1
|
||||||
from trezor.wire.thp_v1 import _CHECKSUM_LENGTH, BROADCAST_CHANNEL_ID
|
from trezor.wire.thp_v1 import BROADCAST_CHANNEL_ID
|
||||||
from trezor.wire.protocol_common import Message
|
from trezor.wire.protocol_common import MessageWithId
|
||||||
import trezor.wire.thp_session as THP
|
import trezor.wire.thp.thp_session as THP
|
||||||
|
from trezor.wire.thp import checksum
|
||||||
from micropython import const
|
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
|
||||||
|
|
||||||
|
|
||||||
class MockHID:
|
class MockHID:
|
||||||
@ -123,10 +123,10 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
# zero length message - just a header
|
# zero length message - just a header
|
||||||
PLAINTEXT = getPlaintext()
|
PLAINTEXT = getPlaintext()
|
||||||
header = make_header(
|
header = make_header(
|
||||||
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
|
PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
||||||
)
|
)
|
||||||
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
|
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
|
||||||
message = header + MESSAGE_TYPE_BYTES + checksum
|
message = header + MESSAGE_TYPE_BYTES + chksum
|
||||||
|
|
||||||
buffer = bytearray(64)
|
buffer = bytearray(64)
|
||||||
gen = thp_v1.read_message(self.interface, buffer)
|
gen = thp_v1.read_message(self.interface, buffer)
|
||||||
@ -145,16 +145,16 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
self.assertEqual(result.data, b"")
|
self.assertEqual(result.data, b"")
|
||||||
|
|
||||||
# message should have been read into the buffer
|
# message should have been read into the buffer
|
||||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + checksum + b"\x00" * 58)
|
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)
|
||||||
|
|
||||||
def test_read_many_packets(self):
|
def test_read_many_packets(self):
|
||||||
message = bytes(range(256))
|
message = bytes(range(256))
|
||||||
header = make_header(
|
header = make_header(
|
||||||
getPlaintext(),
|
getPlaintext(),
|
||||||
COMMON_CID,
|
COMMON_CID,
|
||||||
len(message) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
|
len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
|
||||||
)
|
)
|
||||||
checksum = thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
|
chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
|
||||||
# message = MESSAGE_TYPE_BYTES + message + checksum
|
# message = MESSAGE_TYPE_BYTES + message + checksum
|
||||||
|
|
||||||
# first packet is init header + 59 bytes of data
|
# first packet is init header + 59 bytes of data
|
||||||
@ -163,7 +163,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
|
packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
|
||||||
cont_header + chunk
|
cont_header + chunk
|
||||||
for chunk in chunks(
|
for chunk in chunks(
|
||||||
message[INIT_MESSAGE_DATA_LENGTH:] + checksum,
|
message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
|
||||||
64 - HEADER_CONT_LENGTH,
|
64 - HEADER_CONT_LENGTH,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -185,21 +185,21 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
self.assertEqual(result.data, message)
|
self.assertEqual(result.data, message)
|
||||||
|
|
||||||
# message should have been read into the buffer )
|
# message should have been read into the buffer )
|
||||||
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + checksum)
|
self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)
|
||||||
|
|
||||||
def test_read_large_message(self):
|
def test_read_large_message(self):
|
||||||
message = b"hello world"
|
message = b"hello world"
|
||||||
header = make_header(
|
header = make_header(
|
||||||
getPlaintext(),
|
getPlaintext(),
|
||||||
COMMON_CID,
|
COMMON_CID,
|
||||||
_MESSAGE_TYPE_LEN + len(message) + _CHECKSUM_LENGTH,
|
_MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
|
||||||
)
|
)
|
||||||
|
|
||||||
packet = (
|
packet = (
|
||||||
header
|
header
|
||||||
+ MESSAGE_TYPE_BYTES
|
+ MESSAGE_TYPE_BYTES
|
||||||
+ message
|
+ message
|
||||||
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES + message)
|
+ checksum.compute(header + MESSAGE_TYPE_BYTES + message)
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure we fit into one packet, to make this easier
|
# make sure we fit into one packet, to make this easier
|
||||||
@ -225,7 +225,9 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
self.assertEqual(buffer, b"\x00")
|
self.assertEqual(buffer, b"\x00")
|
||||||
|
|
||||||
def test_write_one_packet(self):
|
def test_write_one_packet(self):
|
||||||
message = Message(MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID))
|
message = MessageWithId(
|
||||||
|
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
|
||||||
|
)
|
||||||
gen = thp_v1.write_message(self.interface, message)
|
gen = thp_v1.write_message(self.interface, message)
|
||||||
|
|
||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
@ -234,19 +236,19 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
gen.send(None)
|
gen.send(None)
|
||||||
|
|
||||||
header = make_header(
|
header = make_header(
|
||||||
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH
|
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
||||||
)
|
)
|
||||||
expected_message = (
|
expected_message = (
|
||||||
header
|
header
|
||||||
+ MESSAGE_TYPE_BYTES
|
+ MESSAGE_TYPE_BYTES
|
||||||
+ thp_v1._compute_checksum_bytes(header + MESSAGE_TYPE_BYTES)
|
+ checksum.compute(header + MESSAGE_TYPE_BYTES)
|
||||||
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - _CHECKSUM_LENGTH)
|
+ b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
|
||||||
)
|
)
|
||||||
self.assertTrue(self.interface.data == [expected_message])
|
self.assertTrue(self.interface.data == [expected_message])
|
||||||
|
|
||||||
def test_write_multiple_packets(self):
|
def test_write_multiple_packets(self):
|
||||||
message_payload = bytes(range(256))
|
message_payload = bytes(range(256))
|
||||||
message = Message(
|
message = MessageWithId(
|
||||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||||
)
|
)
|
||||||
gen = thp_v1.write_message(self.interface, message)
|
gen = thp_v1.write_message(self.interface, message)
|
||||||
@ -254,10 +256,10 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
header = make_header(
|
header = make_header(
|
||||||
PLAINTEXT_1,
|
PLAINTEXT_1,
|
||||||
COMMON_CID,
|
COMMON_CID,
|
||||||
len(message.data) + _MESSAGE_TYPE_LEN + _CHECKSUM_LENGTH,
|
len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
|
||||||
)
|
)
|
||||||
cont_header = make_cont_header()
|
cont_header = make_cont_header()
|
||||||
checksum = thp_v1._compute_checksum_bytes(
|
chksum = checksum.compute(
|
||||||
header + message.type.to_bytes(2, "big") + message.data
|
header + message.type.to_bytes(2, "big") + message.data
|
||||||
)
|
)
|
||||||
packets = [
|
packets = [
|
||||||
@ -265,7 +267,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
] + [
|
] + [
|
||||||
cont_header + chunk
|
cont_header + chunk
|
||||||
for chunk in chunks(
|
for chunk in chunks(
|
||||||
message.data[INIT_MESSAGE_DATA_LENGTH:] + checksum,
|
message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
|
||||||
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
|
thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -290,7 +292,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
|
|
||||||
def test_roundtrip(self):
|
def test_roundtrip(self):
|
||||||
message_payload = bytes(range(256))
|
message_payload = bytes(range(256))
|
||||||
message = Message(
|
message = MessageWithId(
|
||||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||||
)
|
)
|
||||||
gen = thp_v1.write_message(self.interface, message)
|
gen = thp_v1.write_message(self.interface, message)
|
||||||
@ -320,7 +322,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
message_size = (PACKET_COUNT - 1) * (
|
message_size = (PACKET_COUNT - 1) * (
|
||||||
thp_v1._REPORT_LENGTH
|
thp_v1._REPORT_LENGTH
|
||||||
- HEADER_CONT_LENGTH
|
- HEADER_CONT_LENGTH
|
||||||
- _CHECKSUM_LENGTH
|
- CHECKSUM_LENGTH
|
||||||
- _MESSAGE_TYPE_LEN
|
- _MESSAGE_TYPE_LEN
|
||||||
) + INIT_MESSAGE_DATA_LENGTH
|
) + INIT_MESSAGE_DATA_LENGTH
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user