mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
refactor(core): remove unused handlers, clean ChannelContext
This commit is contained in:
parent
7c447ac5d1
commit
6f3db981ec
2
core/src/all_modules.py
generated
2
core/src/all_modules.py
generated
@ -215,8 +215,6 @@ trezor.wire.thp.channel_context
|
||||
import trezor.wire.thp.channel_context
|
||||
trezor.wire.thp.checksum
|
||||
import trezor.wire.thp.checksum
|
||||
trezor.wire.thp.packet_handlers
|
||||
import trezor.wire.thp.packet_handlers
|
||||
trezor.wire.thp.session_context
|
||||
import trezor.wire.thp.session_context
|
||||
trezor.wire.thp.thp_messages
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import utils
|
||||
from trezor.wire import codec_v1, thp_v1
|
||||
from trezor.wire import codec_v1
|
||||
from trezor.wire.protocol_common import MessageWithId
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -10,13 +10,12 @@ if TYPE_CHECKING:
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
|
||||
if utils.USE_THP:
|
||||
return await thp_v1.read_message(iface, buffer)
|
||||
raise Exception("THP protocol should be used instead")
|
||||
return await codec_v1.read_message(iface, buffer)
|
||||
|
||||
|
||||
async def write_message(iface: WireInterface, message: MessageWithId) -> None:
|
||||
if utils.USE_THP:
|
||||
await thp_v1.write_message_with_sync_control(iface, message)
|
||||
return
|
||||
raise Exception("THP protocol should be used instead")
|
||||
await codec_v1.write_message(iface, message.type, message.data)
|
||||
return
|
||||
|
@ -1,11 +1,6 @@
|
||||
import ustruct # pyright: ignore[reportMissingModuleSource]
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import ( # pyright:ignore[reportShadowedImports]
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
)
|
||||
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
|
||||
|
||||
import usb
|
||||
from storage import cache_thp
|
||||
@ -16,8 +11,6 @@ from ..protocol_common import Context
|
||||
from . import ChannelState, SessionState, checksum
|
||||
from . import thp_session as THP
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
|
||||
# from . import thp_session
|
||||
from .thp_messages import (
|
||||
ACK_MESSAGE,
|
||||
CONTINUATION_PACKET,
|
||||
@ -26,28 +19,21 @@ from .thp_messages import (
|
||||
)
|
||||
from .thp_session import ThpError
|
||||
|
||||
# from .thp_session import SessionState, ThpError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface # type:ignore
|
||||
|
||||
Handler = Callable[
|
||||
[bytes, Any, Any, Any], Coroutine
|
||||
] # TODO Adjust parameters to be more restrictive
|
||||
|
||||
|
||||
_INIT_DATA_OFFSET = const(5)
|
||||
_CONT_DATA_OFFSET = const(3)
|
||||
_INIT_DATA_OFFSET = const(5)
|
||||
_REPORT_CONT_DATA_OFFSET = const(3)
|
||||
|
||||
_WIRE_INTERFACE_USB = b"\x01"
|
||||
_MOCK_INTERFACE_HID = b"\x00"
|
||||
|
||||
_PUBKEY_LENGTH = const(32)
|
||||
|
||||
_REPORT_LENGTH = const(64)
|
||||
_MAX_PAYLOAD_LEN = const(60000)
|
||||
INIT_DATA_OFFSET = const(5)
|
||||
CONT_DATA_OFFSET = const(3)
|
||||
|
||||
|
||||
REPORT_LENGTH = const(64)
|
||||
MAX_PAYLOAD_LEN = const(60000)
|
||||
|
||||
|
||||
class ChannelContext(Context):
|
||||
@ -123,11 +109,11 @@ class ChannelContext(Context):
|
||||
async def _handle_cont_packet(self, packet):
|
||||
if not self.is_cont_packet_expected:
|
||||
return # Continuation packet is not expected, ignoring
|
||||
await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET)
|
||||
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
|
||||
|
||||
async def _handle_completed_message(self):
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
|
||||
msg_len = payload_length + _INIT_DATA_OFFSET
|
||||
msg_len = payload_length + INIT_DATA_OFFSET
|
||||
if not checksum.is_valid(
|
||||
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
|
||||
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
|
||||
@ -152,7 +138,7 @@ class ChannelContext(Context):
|
||||
"Message received is not a valid handshake init request!"
|
||||
)
|
||||
host_ephemeral_key = bytearray(
|
||||
self.buffer[_INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH]
|
||||
self.buffer[INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH]
|
||||
)
|
||||
cache_thp.set_channel_host_ephemeral_key(
|
||||
self.channel_cache, host_ephemeral_key
|
||||
@ -170,7 +156,7 @@ class ChannelContext(Context):
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
self._decrypt_buffer()
|
||||
session_id, message_type = ustruct.unpack(
|
||||
">BH", self.buffer[_INIT_DATA_OFFSET:]
|
||||
">BH", self.buffer[INIT_DATA_OFFSET:]
|
||||
)
|
||||
if session_id not in self.sessions:
|
||||
raise Exception("Unalloacted session")
|
||||
@ -181,15 +167,15 @@ class ChannelContext(Context):
|
||||
|
||||
await self.sessions[session_id].receive_message(
|
||||
message_type,
|
||||
self.buffer[_INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
|
||||
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
|
||||
)
|
||||
|
||||
if state is ChannelState.TH2:
|
||||
host_encrypted_static_pubkey = self.buffer[
|
||||
_INIT_DATA_OFFSET : _INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
||||
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
||||
]
|
||||
handshake_completion_request_noise_payload = self.buffer[
|
||||
_INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
|
||||
INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
|
||||
]
|
||||
print(
|
||||
host_encrypted_static_pubkey,
|
||||
@ -216,45 +202,6 @@ class ChannelContext(Context):
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
def _get_handler(self) -> Handler:
|
||||
state = self.get_channel_state()
|
||||
if state is ChannelState.UNAUTHENTICATED:
|
||||
return self._handler_unauthenticated
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
return self._handler_encrypted_transport
|
||||
raise Exception("Unimplemented situation")
|
||||
|
||||
# Handlers for init packets
|
||||
# TODO adjust
|
||||
async def _handler_encrypted_transport(
|
||||
self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet
|
||||
) -> None:
|
||||
self.expected_payload_length = payload_length
|
||||
self.bytes_read = 0
|
||||
|
||||
await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET)
|
||||
# TODO Set/Provide different buffer for management session
|
||||
|
||||
if self.expected_payload_length == self.bytes_read:
|
||||
self._finish_message()
|
||||
else:
|
||||
self.is_cont_packet_expected = True
|
||||
|
||||
# TODO adjust
|
||||
async def _handler_unauthenticated(
|
||||
self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet
|
||||
) -> None:
|
||||
self.expected_payload_length = payload_length
|
||||
self.bytes_read = 0
|
||||
|
||||
await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET)
|
||||
# TODO Set/Provide different buffer for management session
|
||||
|
||||
if self.expected_payload_length == self.bytes_read:
|
||||
self._finish_message()
|
||||
else:
|
||||
self.is_cont_packet_expected = True
|
||||
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
@ -321,6 +268,24 @@ def _encode_iface(iface: WireInterface) -> bytes:
|
||||
raise Exception("Unknown WireInterface")
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
if payload_length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
# allocate a new buffer to fit the message
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
# reuse a part of the supplied buffer
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
||||
|
||||
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0x80 == CONTINUATION_PACKET
|
||||
|
||||
@ -335,21 +300,3 @@ def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool:
|
||||
|
||||
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == ACK_MESSAGE
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
if payload_length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
# allocate a new buffer to fit the message
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(_REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
# reuse a part of the supplied buffer
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
@ -1,14 +0,0 @@
|
||||
from . import ChannelState
|
||||
from .channel_context import ChannelContext
|
||||
|
||||
|
||||
def getPacketHandler(
|
||||
channel: ChannelContext, packet: bytes
|
||||
): # TODO is the packet bytes or BufferType?
|
||||
if channel.get_channel_state is ChannelState.TH1: # TODO is correct
|
||||
# return handler_TH_1
|
||||
pass
|
||||
|
||||
|
||||
def handler_TH_1(packet):
|
||||
pass
|
@ -9,10 +9,10 @@ from .protocol_common import MessageWithId
|
||||
from .thp import ChannelState, ack_handler, checksum, thp_messages
|
||||
from .thp import thp_session as THP
|
||||
from .thp.channel_context import (
|
||||
_INIT_DATA_OFFSET,
|
||||
_MAX_PAYLOAD_LEN,
|
||||
_REPORT_CONT_DATA_OFFSET,
|
||||
_REPORT_LENGTH,
|
||||
CONT_DATA_OFFSET,
|
||||
INIT_DATA_OFFSET,
|
||||
MAX_PAYLOAD_LEN,
|
||||
REPORT_LENGTH,
|
||||
ChannelContext,
|
||||
load_cached_channels,
|
||||
)
|
||||
@ -192,7 +192,7 @@ def _get_loop_wait_read(iface: WireInterface):
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
if payload_length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
@ -201,7 +201,7 @@ def _get_buffer_for_payload(
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(_REPORT_LENGTH)
|
||||
payload = bytearray(REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
@ -213,7 +213,7 @@ async def _buffer_received_data(
|
||||
payload: utils.BufferType, header: InitHeader, iface, report
|
||||
) -> None | InterruptingInitPacket:
|
||||
# buffer the initial data
|
||||
nread = utils.memcpy(payload, 0, report, _INIT_DATA_OFFSET)
|
||||
nread = utils.memcpy(payload, 0, report, INIT_DATA_OFFSET)
|
||||
while nread < header.length:
|
||||
# wait for continuation report
|
||||
report = await _get_loop_wait_read(iface)
|
||||
@ -237,7 +237,7 @@ async def _buffer_received_data(
|
||||
continue
|
||||
|
||||
# buffer the continuation data
|
||||
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
|
||||
nread += utils.memcpy(payload, nread, report, CONT_DATA_OFFSET)
|
||||
|
||||
|
||||
async def write_message_with_sync_control(
|
||||
@ -302,11 +302,11 @@ async def write_to_wire(
|
||||
payload_length = len(payload)
|
||||
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(_REPORT_LENGTH)
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, _INIT_DATA_OFFSET, payload, 0)
|
||||
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
|
||||
await _write_report(loop_write, iface, report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
@ -314,7 +314,7 @@ async def write_to_wire(
|
||||
header.pack_to_cont_buffer(report)
|
||||
|
||||
while nwritten < payload_length:
|
||||
nwritten += utils.memcpy(report, _REPORT_CONT_DATA_OFFSET, payload, nwritten)
|
||||
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
|
||||
await _write_report(loop_write, iface, report)
|
||||
|
||||
|
||||
|
@ -322,7 +322,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||
|
||||
# ensure that a message this big won't fit into memory
|
||||
# Note: this control is changed, because THP has only 2 byte length field
|
||||
self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN)
|
||||
self.assertTrue(message_size > thp_v1.MAX_PAYLOAD_LEN)
|
||||
# self.assertRaises(MemoryError, bytearray, message_size)
|
||||
header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
|
||||
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
|
||||
|
@ -183,7 +183,7 @@ class Field:
|
||||
|
||||
class _MessageTypeMeta(type):
|
||||
def __init__(cls, name: str, bases: tuple, d: dict) -> None:
|
||||
super().__init__(name, bases, d) # type: ignore [Expected 1 positional argument]
|
||||
super().__init__(name, bases, d)
|
||||
if name != "MessageType":
|
||||
cls.__init__ = MessageType.__init__ # type: ignore ["__init__" is obscured by a declaration of the same name;;Cannot assign member "__init__" for type "_MessageTypeMeta"]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user