Fix imports in thp_v1.py

M1nd3r/thp5
M1nd3r 2 months ago
parent bd99a471e4
commit 41551ffffa

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from trezor import utils from trezor import utils
from trezor.wire import codec_v1, thp_v1 from trezor.wire import codec_v1
from trezor.wire.protocol_common import MessageWithId from trezor.wire.protocol_common import MessageWithId
if TYPE_CHECKING: if TYPE_CHECKING:
@ -10,13 +10,12 @@ if TYPE_CHECKING:
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
if utils.USE_THP: if utils.USE_THP:
return await thp_v1.read_message(iface, buffer) raise Exception("THP protocol should be used instead")
return await codec_v1.read_message(iface, buffer) return await codec_v1.read_message(iface, buffer)
async def write_message(iface: WireInterface, message: MessageWithId) -> None: async def write_message(iface: WireInterface, message: MessageWithId) -> None:
if utils.USE_THP: if utils.USE_THP:
await thp_v1.write_message_with_sync_control(iface, message) raise Exception("THP protocol should be used instead")
return
await codec_v1.write_message(iface, message.type, message.data) await codec_v1.write_message(iface, message.type, message.data)
return return

@ -13,11 +13,9 @@ from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH from .checksum import CHECKSUM_LENGTH
from .thp_messages import ( from .thp_messages import (
ACK_MESSAGE, ACK_MESSAGE,
CONT_DATA_OFFSET,
CONTINUATION_PACKET, CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT, ENCRYPTED_TRANSPORT,
HANDSHAKE_INIT, HANDSHAKE_INIT,
INIT_DATA_OFFSET,
) )
from .thp_session import ThpError from .thp_session import ThpError
@ -30,8 +28,12 @@ _MOCK_INTERFACE_HID = b"\x00"
_PUBKEY_LENGTH = const(32) _PUBKEY_LENGTH = const(32)
_REPORT_LENGTH = const(64) INIT_DATA_OFFSET = const(5)
_MAX_PAYLOAD_LEN = const(60000) CONT_DATA_OFFSET = const(3)
REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000)
class ChannelContext(Context): class ChannelContext(Context):
@ -267,7 +269,7 @@ def _encode_iface(iface: WireInterface) -> bytes:
def _get_buffer_for_payload( def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType: ) -> utils.BufferType:
if payload_length > max_length: if payload_length > max_length:
raise ThpError("Message too large") raise ThpError("Message too large")
@ -276,7 +278,7 @@ def _get_buffer_for_payload(
try: try:
payload: utils.BufferType = bytearray(payload_length) payload: utils.BufferType = bytearray(payload_length)
except MemoryError: except MemoryError:
payload = bytearray(_REPORT_LENGTH) payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large") raise ThpError("Message too large")
return payload return payload

@ -12,8 +12,6 @@ HANDSHAKE_INIT = 0x00
ACK_MESSAGE = 0x20 ACK_MESSAGE = 0x20
_ERROR = 0x41 _ERROR = 0x41
_CHANNEL_ALLOCATION_RES = 0x40 _CHANNEL_ALLOCATION_RES = 0x40
INIT_DATA_OFFSET = const(5)
CONT_DATA_OFFSET = const(3)
class InitHeader: class InitHeader:

@ -9,18 +9,18 @@ from .protocol_common import MessageWithId
from .thp import ChannelState, ack_handler, checksum, thp_messages from .thp import ChannelState, ack_handler, checksum, thp_messages
from .thp import thp_session as THP from .thp import thp_session as THP
from .thp.channel_context import ( from .thp.channel_context import (
_MAX_PAYLOAD_LEN, CONT_DATA_OFFSET,
_REPORT_LENGTH, INIT_DATA_OFFSET,
MAX_PAYLOAD_LEN,
REPORT_LENGTH,
ChannelContext, ChannelContext,
load_cached_channels, load_cached_channels,
) )
from .thp.checksum import CHECKSUM_LENGTH from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import ( from .thp.thp_messages import (
CODEC_V1, CODEC_V1,
CONT_DATA_OFFSET,
CONTINUATION_PACKET, CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT, ENCRYPTED_TRANSPORT,
INIT_DATA_OFFSET,
InitHeader, InitHeader,
InterruptingInitPacket, InterruptingInitPacket,
) )
@ -192,7 +192,7 @@ def _get_loop_wait_read(iface: WireInterface):
def _get_buffer_for_payload( def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType: ) -> utils.BufferType:
if payload_length > max_length: if payload_length > max_length:
raise ThpError("Message too large") raise ThpError("Message too large")
@ -201,7 +201,7 @@ def _get_buffer_for_payload(
try: try:
payload: utils.BufferType = bytearray(payload_length) payload: utils.BufferType = bytearray(payload_length)
except MemoryError: except MemoryError:
payload = bytearray(_REPORT_LENGTH) payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large") raise ThpError("Message too large")
return payload return payload
@ -302,7 +302,7 @@ async def write_to_wire(
payload_length = len(payload) payload_length = len(payload)
# prepare the report buffer with header data # prepare the report buffer with header data
report = bytearray(_REPORT_LENGTH) report = bytearray(REPORT_LENGTH)
header.pack_to_buffer(report) header.pack_to_buffer(report)
# write initial report # write initial report

@ -322,7 +322,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# ensure that a message this big won't fit into memory # ensure that a message this big won't fit into memory
# Note: this control is changed, because THP has only 2 byte length field # Note: this control is changed, because THP has only 2 byte length field
self.assertTrue(message_size > thp_v1._MAX_PAYLOAD_LEN) self.assertTrue(message_size > thp_v1.MAX_PAYLOAD_LEN)
# self.assertRaises(MemoryError, bytearray, message_size) # self.assertRaises(MemoryError, bytearray, message_size)
header = make_header(PLAINTEXT_1, COMMON_CID, message_size) header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH) packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)

Loading…
Cancel
Save