mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
refactor(core): clean thp code
This commit is contained in:
parent
209e548ab5
commit
3f590bc11d
@ -103,13 +103,6 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
|
||||
|
||||
thp_v1.set_buffer(ctx_buffer)
|
||||
|
||||
if __debug__ and is_debug_session:
|
||||
import apps.debug
|
||||
|
||||
print(apps.debug.DEBUG_CONTEXT) # TODO remove
|
||||
|
||||
# TODO add debug context or smth to apps.debug
|
||||
|
||||
# Take a mark of modules that are imported at this point, so we can
|
||||
# roll back and un-import any others.
|
||||
modules = utils.unimport_begin()
|
||||
|
@ -2,12 +2,11 @@ import ustruct # pyright: ignore[reportMissingModuleSource]
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||
from trezor import io, log, loop, utils
|
||||
|
||||
from .protocol_common import MessageWithId
|
||||
from .thp import ChannelState, ack_handler, checksum, thp_messages
|
||||
from .thp import thp_session as THP
|
||||
from .thp import ChannelState, checksum, thp_messages
|
||||
from .thp.channel import (
|
||||
CONT_DATA_OFFSET,
|
||||
INIT_DATA_OFFSET,
|
||||
@ -17,22 +16,14 @@ from .thp.channel import (
|
||||
load_cached_channels,
|
||||
)
|
||||
from .thp.checksum import CHECKSUM_LENGTH
|
||||
from .thp.thp_messages import (
|
||||
CODEC_V1,
|
||||
CONTINUATION_PACKET,
|
||||
ENCRYPTED_TRANSPORT,
|
||||
InitHeader,
|
||||
InterruptingInitPacket,
|
||||
)
|
||||
from .thp.thp_session import SessionState, ThpError
|
||||
from .thp.thp_messages import CODEC_V1, InitHeader
|
||||
from .thp.thp_session import ThpError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
|
||||
|
||||
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
|
||||
_CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_ACK_MESSAGE = 0x20
|
||||
_PLAINTEXT = 0x01
|
||||
|
||||
|
||||
_BUFFER: bytearray
|
||||
@ -85,220 +76,31 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
# TODO add cleaning sequence if no workflow/channel is active (or some condition like that)
|
||||
|
||||
|
||||
async def deprecated_read_message(
|
||||
iface: WireInterface, buffer: utils.BufferType
|
||||
) -> MessageWithId:
|
||||
msg = await deprecated_read_message_or_init_packet(iface, buffer)
|
||||
while type(msg) is not MessageWithId:
|
||||
if isinstance(msg, InterruptingInitPacket):
|
||||
msg = await deprecated_read_message_or_init_packet(
|
||||
iface, buffer, msg.initReport
|
||||
)
|
||||
else:
|
||||
raise ThpError("Unexpected output of read_message_or_init_packet:")
|
||||
return msg
|
||||
|
||||
|
||||
async def deprecated_read_message_or_init_packet(
|
||||
iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None
|
||||
) -> MessageWithId | InterruptingInitPacket:
|
||||
report = firstReport
|
||||
while True:
|
||||
# Wait for an initial report
|
||||
if report is None:
|
||||
report = await loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
if report is None:
|
||||
raise ThpError("Reading failed unexpectedly, report is None.")
|
||||
|
||||
# Channel multiplexing
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", report)
|
||||
|
||||
if cid == BROADCAST_CHANNEL_ID:
|
||||
await _handle_broadcast(iface, ctrl_byte, report)
|
||||
report = None
|
||||
continue
|
||||
|
||||
# We allow for only one message to be read simultaneously. We do not
|
||||
# support reading multiple messages with interleaven packets - with
|
||||
# the sole exception of cid_request which can be handled independently.
|
||||
if _is_ctrl_byte_continuation(ctrl_byte):
|
||||
# continuation packet is not expected - ignore
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received unexpected continuation packet")
|
||||
report = None
|
||||
continue
|
||||
payload_length = ustruct.unpack(">H", report[3:])[0]
|
||||
payload = _get_buffer_for_payload(payload_length, buffer)
|
||||
header = InitHeader(ctrl_byte, cid, payload_length)
|
||||
|
||||
# buffer the received data
|
||||
interruptingPacket = await _deprecated_buffer_received_data(
|
||||
payload, header, iface, report
|
||||
)
|
||||
if interruptingPacket is not None:
|
||||
return interruptingPacket
|
||||
|
||||
# Check CRC
|
||||
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||
# checksum is not valid -> ignore message
|
||||
report = None
|
||||
continue
|
||||
|
||||
session = THP.get_session(iface, cid)
|
||||
session_state = THP.get_state(session)
|
||||
|
||||
# Handle message on unallocated channel
|
||||
if session_state == SessionState.UNALLOCATED:
|
||||
message = await _handle_unallocated(iface, cid)
|
||||
# unallocated should not return regular message, TODO, but it might change
|
||||
if __debug__:
|
||||
log.debug(__name__, "Channel with id: %d in UNALLOCATED", cid)
|
||||
if message is not None:
|
||||
return message
|
||||
report = None
|
||||
continue
|
||||
|
||||
if session is None:
|
||||
raise ThpError("Invalid session!")
|
||||
|
||||
# Note: In the Host, the UNALLOCATED_CHANNEL error should be handled here
|
||||
|
||||
# Synchronization process
|
||||
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||
|
||||
# 1: Handle ACKs
|
||||
if _is_ctrl_byte_ack(ctrl_byte):
|
||||
ack_handler.handle_received_ACK(session, sync_bit)
|
||||
report = None
|
||||
continue
|
||||
|
||||
# 2: Handle message with unexpected synchronization bit
|
||||
if sync_bit != THP.sync_get_receive_expected_bit(session):
|
||||
message = await _handle_unexpected_sync_bit(iface, cid, sync_bit)
|
||||
# unsynchronized messages should not return regular message, TODO,
|
||||
# but it might change with the cancelation message
|
||||
if message is not None:
|
||||
return message
|
||||
report = None
|
||||
continue
|
||||
|
||||
# 3: Send ACK in response
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
||||
cid,
|
||||
sync_bit,
|
||||
)
|
||||
await _sendAck(iface, cid, sync_bit)
|
||||
THP.sync_set_receive_expected_bit(session, 1 - sync_bit)
|
||||
|
||||
return await _handle_allocated(ctrl_byte, session, payload)
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
if payload_length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
# allocate a new buffer to fit the message
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
return _try_allocate_new_buffer(payload_length)
|
||||
return _reuse_existing_buffer(payload_length, existing_buffer)
|
||||
|
||||
# reuse a part of the supplied buffer
|
||||
|
||||
def _try_allocate_new_buffer(payload_length: int) -> utils.BufferType:
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
|
||||
def _reuse_existing_buffer(
|
||||
payload_length: int, existing_buffer: utils.BufferType
|
||||
) -> utils.BufferType:
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
||||
|
||||
async def _deprecated_buffer_received_data(
|
||||
payload: utils.BufferType, header: InitHeader, iface, report
|
||||
) -> None | InterruptingInitPacket:
|
||||
# buffer the initial data
|
||||
nread = utils.memcpy(payload, 0, report, INIT_DATA_OFFSET)
|
||||
while nread < header.length:
|
||||
# wait for continuation report
|
||||
report = await loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
# channel multiplexing
|
||||
cont_ctrl_byte, cont_cid = ustruct.unpack(">BH", report)
|
||||
|
||||
# handle broadcast - allows the reading process
|
||||
# to survive interruption by broadcast
|
||||
if cont_cid == BROADCAST_CHANNEL_ID:
|
||||
await _handle_broadcast(iface, cont_ctrl_byte, report)
|
||||
continue
|
||||
|
||||
# handle unexpected initiation packet
|
||||
if not _is_ctrl_byte_continuation(cont_ctrl_byte):
|
||||
# TODO possibly add timeout - allow interruption only after a long time
|
||||
return InterruptingInitPacket(report)
|
||||
|
||||
# ignore continuation packets on different channels
|
||||
if cont_cid != header.cid:
|
||||
continue
|
||||
|
||||
# buffer the continuation data
|
||||
nread += utils.memcpy(payload, nread, report, CONT_DATA_OFFSET)
|
||||
|
||||
|
||||
async def write_message_with_sync_control(
|
||||
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
|
||||
) -> None:
|
||||
session = THP.get_session_from_id(message.session_id)
|
||||
if session is None:
|
||||
raise ThpError("Invalid session")
|
||||
if (not THP.sync_can_send_message(session)) and (not is_retransmission):
|
||||
raise ThpError("Cannot send another message before ACK is received.")
|
||||
await write_message(iface, message, is_retransmission)
|
||||
|
||||
|
||||
async def write_message(
|
||||
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
|
||||
) -> None:
|
||||
session = THP.get_session_from_id(message.session_id)
|
||||
if session is None:
|
||||
raise ThpError("Invalid session")
|
||||
|
||||
cid = THP.get_cid(session)
|
||||
payload = message.type.to_bytes(2, "big") + message.data
|
||||
payload_length = len(payload)
|
||||
|
||||
if THP.get_state(session) == SessionState.INITIALIZED:
|
||||
# write message in plaintext, TODO check if it is allowed
|
||||
ctrl_byte = _PLAINTEXT
|
||||
elif THP.get_state(session) == SessionState.APP_TRAFFIC:
|
||||
ctrl_byte = ENCRYPTED_TRANSPORT
|
||||
else:
|
||||
raise ThpError("Session in not implemented state" + str(THP.get_state(session)))
|
||||
|
||||
if not is_retransmission:
|
||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(
|
||||
ctrl_byte, THP.sync_get_send_bit(session)
|
||||
)
|
||||
THP.sync_set_send_bit_to_opposite(session)
|
||||
else:
|
||||
# retransmission must have the same sync bit as the previously sent message
|
||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(
|
||||
ctrl_byte, 1 - THP.sync_get_send_bit(session)
|
||||
)
|
||||
|
||||
header = InitHeader(ctrl_byte, cid, payload_length + CHECKSUM_LENGTH)
|
||||
chksum = checksum.compute(header.to_bytes() + payload)
|
||||
if __debug__ and message.session_id is not None:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Writing message with type %d to a session %d",
|
||||
message.type,
|
||||
int.from_bytes(message.session_id, "big"),
|
||||
)
|
||||
await write_to_wire(iface, header, payload + chksum)
|
||||
# TODO set timeout for retransmission
|
||||
|
||||
|
||||
async def write_to_wire(
|
||||
iface: WireInterface, header: InitHeader, payload: bytes
|
||||
) -> None:
|
||||
@ -364,21 +166,6 @@ async def _handle_broadcast(
|
||||
await write_to_wire(iface, response_header, response_data + chksum)
|
||||
|
||||
|
||||
async def _handle_allocated(
|
||||
ctrl_byte, session: SessionThpCache, payload
|
||||
) -> MessageWithId:
|
||||
# Parameters session and ctrl_byte will be used to determine if the
|
||||
# communication should be encrypted or not
|
||||
|
||||
message_type = ustruct.unpack(">H", payload)[0]
|
||||
|
||||
# trim message type and checksum from payload
|
||||
message_data = payload[2:-CHECKSUM_LENGTH]
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received valid message with type %d", message_type)
|
||||
return MessageWithId(message_type, message_data, session.session_id)
|
||||
|
||||
|
||||
async def _handle_unallocated(iface, cid) -> MessageWithId | None:
|
||||
data = thp_messages.get_error_unallocated_channel()
|
||||
header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
|
||||
@ -386,35 +173,13 @@ async def _handle_unallocated(iface, cid) -> MessageWithId | None:
|
||||
await write_to_wire(iface, header, data + chksum)
|
||||
|
||||
|
||||
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
|
||||
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
||||
header = InitHeader(ctrl_byte, cid, CHECKSUM_LENGTH)
|
||||
chksum = checksum.compute(header.to_bytes())
|
||||
await write_to_wire(iface, header, chksum)
|
||||
async def deprecated_read_message(
|
||||
iface: WireInterface, buffer: utils.BufferType
|
||||
) -> MessageWithId:
|
||||
return MessageWithId(-1, b"\x00")
|
||||
|
||||
|
||||
async def _handle_unexpected_sync_bit(
|
||||
iface: WireInterface, cid: int, sync_bit: int
|
||||
) -> MessageWithId | None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received message has unexpected synchronization bit")
|
||||
await _sendAck(iface, cid, sync_bit)
|
||||
|
||||
# TODO handle cancelation messages and messages on allocated channels without synchronization
|
||||
# (some such messages might be handled in the classical "allocated" way, if the sync bit is right)
|
||||
|
||||
|
||||
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
|
||||
return ctrl_byte & 0x80 == CONTINUATION_PACKET
|
||||
|
||||
|
||||
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
|
||||
return ctrl_byte & 0x20 == _ACK_MESSAGE
|
||||
|
||||
|
||||
def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
|
||||
if sync_bit == 0:
|
||||
return ctrl_byte & 0xEF
|
||||
if sync_bit == 1:
|
||||
return ctrl_byte | 0x10
|
||||
raise ThpError("Unexpected synchronization bit")
|
||||
async def deprecated_write_message(
|
||||
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
|
||||
) -> None:
|
||||
pass
|
||||
|
@ -223,7 +223,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||
message = MessageWithId(
|
||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.write_message(self.interface, message)
|
||||
gen = thp_v1.deprecated_write_message(self.interface, message)
|
||||
# exhaust the iterator:
|
||||
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
||||
for query in gen:
|
||||
@ -248,7 +248,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||
message = MessageWithId(
|
||||
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.write_message(self.interface, message)
|
||||
gen = thp_v1.deprecated_write_message(self.interface, message)
|
||||
|
||||
query = gen.send(None)
|
||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||
@ -271,7 +271,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||
message = MessageWithId(
|
||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||
)
|
||||
gen = thp_v1.write_message(self.interface, message)
|
||||
gen = thp_v1.deprecated_write_message(self.interface, message)
|
||||
|
||||
header = make_header(
|
||||
PLAINTEXT_1,
|
||||
|
Loading…
Reference in New Issue
Block a user