refactor(core): remove unused handlers, clean ChannelContext

M1nd3r/thp6
M1nd3r 2 months ago
parent 7c447ac5d1
commit 6f3db981ec

@ -215,8 +215,6 @@ trezor.wire.thp.channel_context
import trezor.wire.thp.channel_context
trezor.wire.thp.checksum
import trezor.wire.thp.checksum
trezor.wire.thp.packet_handlers
import trezor.wire.thp.packet_handlers
trezor.wire.thp.session_context
import trezor.wire.thp.session_context
trezor.wire.thp.thp_messages

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING
from trezor import utils
from trezor.wire import codec_v1, thp_v1
from trezor.wire import codec_v1
from trezor.wire.protocol_common import MessageWithId
if TYPE_CHECKING:
@ -10,13 +10,12 @@ if TYPE_CHECKING:
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
if utils.USE_THP:
return await thp_v1.read_message(iface, buffer)
raise Exception("THP protocol should be used instead")
return await codec_v1.read_message(iface, buffer)
async def write_message(iface: WireInterface, message: MessageWithId) -> None:
if utils.USE_THP:
await thp_v1.write_message_with_sync_control(iface, message)
return
raise Exception("THP protocol should be used instead")
await codec_v1.write_message(iface, message.type, message.data)
return

@ -1,11 +1,6 @@
import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import ( # pyright:ignore[reportShadowedImports]
TYPE_CHECKING,
Any,
Callable,
Coroutine,
)
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
import usb
from storage import cache_thp
@ -16,8 +11,6 @@ from ..protocol_common import Context
from . import ChannelState, SessionState, checksum
from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH
# from . import thp_session
from .thp_messages import (
ACK_MESSAGE,
CONTINUATION_PACKET,
@ -26,28 +19,21 @@ from .thp_messages import (
)
from .thp_session import ThpError
# from .thp_session import SessionState, ThpError
if TYPE_CHECKING:
from trezorio import WireInterface # type:ignore
Handler = Callable[
[bytes, Any, Any, Any], Coroutine
] # TODO Adjust parameters to be more restrictive
_INIT_DATA_OFFSET = const(5)
_CONT_DATA_OFFSET = const(3)
_INIT_DATA_OFFSET = const(5)
_REPORT_CONT_DATA_OFFSET = const(3)
_WIRE_INTERFACE_USB = b"\x01"
_MOCK_INTERFACE_HID = b"\x00"
_PUBKEY_LENGTH = const(32)
_REPORT_LENGTH = const(64)
_MAX_PAYLOAD_LEN = const(60000)
INIT_DATA_OFFSET = const(5)
CONT_DATA_OFFSET = const(3)
REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000)
class ChannelContext(Context):
@ -123,11 +109,11 @@ class ChannelContext(Context):
async def _handle_cont_packet(self, packet):
if not self.is_cont_packet_expected:
return # Continuation packet is not expected, ignoring
await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET)
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
async def _handle_completed_message(self):
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
msg_len = payload_length + _INIT_DATA_OFFSET
msg_len = payload_length + INIT_DATA_OFFSET
if not checksum.is_valid(
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
@ -152,7 +138,7 @@ class ChannelContext(Context):
"Message received is not a valid handshake init request!"
)
host_ephemeral_key = bytearray(
self.buffer[_INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH]
self.buffer[INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(
self.channel_cache, host_ephemeral_key
@ -170,7 +156,7 @@ class ChannelContext(Context):
if state is ChannelState.ENCRYPTED_TRANSPORT:
self._decrypt_buffer()
session_id, message_type = ustruct.unpack(
">BH", self.buffer[_INIT_DATA_OFFSET:]
">BH", self.buffer[INIT_DATA_OFFSET:]
)
if session_id not in self.sessions:
raise Exception("Unalloacted session")
@ -181,15 +167,15 @@ class ChannelContext(Context):
await self.sessions[session_id].receive_message(
message_type,
self.buffer[_INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
)
if state is ChannelState.TH2:
host_encrypted_static_pubkey = self.buffer[
_INIT_DATA_OFFSET : _INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = self.buffer[
_INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
]
print(
host_encrypted_static_pubkey,
@ -216,45 +202,6 @@ class ChannelContext(Context):
self.expected_payload_length = 0
self.is_cont_packet_expected = False
def _get_handler(self) -> Handler:
state = self.get_channel_state()
if state is ChannelState.UNAUTHENTICATED:
return self._handler_unauthenticated
if state is ChannelState.ENCRYPTED_TRANSPORT:
return self._handler_encrypted_transport
raise Exception("Unimplemented situation")
# Handlers for init packets
# TODO adjust
async def _handler_encrypted_transport(
self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet
) -> None:
self.expected_payload_length = payload_length
self.bytes_read = 0
await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET)
# TODO Set/Provide different buffer for management session
if self.expected_payload_length == self.bytes_read:
self._finish_message()
else:
self.is_cont_packet_expected = True
# TODO adjust
async def _handler_unauthenticated(
self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet
) -> None:
self.expected_payload_length = payload_length
self.bytes_read = 0
await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET)
# TODO Set/Provide different buffer for management session
if self.expected_payload_length == self.bytes_read:
self._finish_message()
else:
self.is_cont_packet_expected = True
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
@ -321,24 +268,8 @@ def _encode_iface(iface: WireInterface) -> bytes:
raise Exception("Unknown WireInterface")
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & 0x80 == CONTINUATION_PACKET
def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_INIT
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
if payload_length > max_length:
raise ThpError("Message too large")
@ -347,9 +278,25 @@ def _get_buffer_for_payload(
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(_REPORT_LENGTH)
payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:payload_length]
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & 0x80 == CONTINUATION_PACKET
def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_INIT
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE

@ -1,14 +0,0 @@
from . import ChannelState
from .channel_context import ChannelContext
def getPacketHandler(
channel: ChannelContext, packet: bytes
): # TODO is the packet bytes or BufferType?
if channel.get_channel_state is ChannelState.TH1: # TODO is correct
# return handler_TH_1
pass
def handler_TH_1(packet):
pass

@ -9,10 +9,10 @@ from .protocol_common import MessageWithId
from .thp import ChannelState, ack_handler, checksum, thp_messages
from .thp import thp_session as THP
from .thp.channel_context import (
_INIT_DATA_OFFSET,
_MAX_PAYLOAD_LEN,
_REPORT_CONT_DATA_OFFSET,
_REPORT_LENGTH,
CONT_DATA_OFFSET,
INIT_DATA_OFFSET,
MAX_PAYLOAD_LEN,
REPORT_LENGTH,
ChannelContext,
load_cached_channels,
)
@ -192,7 +192,7 @@ def _get_loop_wait_read(iface: WireInterface):
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
if payload_length > max_length:
raise ThpError("Message too large")
@ -201,7 +201,7 @@ def _get_buffer_for_payload(
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(_REPORT_LENGTH)
payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large")
return payload
@ -213,7 +213,7 @@ async def _buffer_received_data(
payload: utils.BufferType, header: InitHeader, iface, report
) -> None | InterruptingInitPacket:
# buffer the initial data
nread = utils.memcpy(payload, 0, report, _INIT_DATA_OFFSET)
nread = utils.memcpy(payload, 0, report, INIT_DATA_OFFSET)
while nread < header.length:
# wait for continuation report
report = await _get_loop_wait_read(iface)
@ -237,7 +237,7 @@ async def _buffer_received_data(
continue
# buffer the continuation data
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
nread += utils.memcpy(payload, nread, report, CONT_DATA_OFFSET)
async def write_message_with_sync_control(
@ -302,11 +302,11 @@ async def write_to_wire(
payload_length = len(payload)
# prepare the report buffer with header data
report = bytearray(_REPORT_LENGTH)
report = bytearray(REPORT_LENGTH)
header.pack_to_buffer(report)
# write initial report
nwritten = utils.memcpy(report, _INIT_DATA_OFFSET, payload, 0)
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
await _write_report(loop_write, iface, report)
# if we have more data to write, use continuation reports for it
@ -314,7 +314,7 @@ async def write_to_wire(
header.pack_to_cont_buffer(report)
while nwritten < payload_length:
nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten)
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
await _write_report(loop_write, iface, report)

@ -322,7 +322,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# ensure that a message this big won't fit into memory
# Note: this control is changed, because THP has only 2 byte length field
self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN)
self.assertTrue(message_size > thp_v1.MAX_PAYLOAD_LEN)
# self.assertRaises(MemoryError, bytearray, message_size)
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)

@ -183,7 +183,7 @@ class Field:
class _MessageTypeMeta(type):
def __init__(cls, name: str, bases: tuple, d: dict) -> None:
super().__init__(name, bases, d) # type: ignore [Expected 1 positional argument]
super().__init__(name, bases, d)
if name != "MessageType":
cls.__init__ = MessageType.__init__ # type: ignore ["__init__" is obscured by a declaration of the same name;;Cannot assign member "__init__" for type "_MessageTypeMeta"]

Loading…
Cancel
Save