From 8aef86725985381b23c31d5c3c88bb08bbe73128 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 4 Dec 2024 16:21:06 +0100 Subject: [PATCH] refactor(core): move wire buffer handling completely to memory_manager [no changelog] --- core/src/trezor/wire/__init__.py | 14 ++++------- core/src/trezor/wire/thp/channel.py | 2 +- core/src/trezor/wire/thp/channel_manager.py | 8 ++---- core/src/trezor/wire/thp/memory_manager.py | 22 ++++++++-------- core/src/trezor/wire/thp/thp_main.py | 28 ++------------------- core/tests/test_trezor.wire.thp.py | 8 +++--- 6 files changed, 24 insertions(+), 58 deletions(-) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 1bc847f273..287ab3377b 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -42,12 +42,6 @@ from .message_handler import failure # other packages. from .errors import * # isort:skip # noqa: F401,F403 -_PROTOBUF_BUFFER_SIZE = const(8192) - -WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) -if utils.USE_THP: - WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE) - if TYPE_CHECKING: from trezorio import WireInterface from typing import Any, Callable, Coroutine, TypeVar @@ -65,12 +59,12 @@ def setup(iface: WireInterface) -> None: if utils.USE_THP: + # memory_manager is imported to create READ/WRITE buffers + # in more stable area of memory + from .thp import memory_manager # noqa: F401 async def handle_session(iface: WireInterface) -> None: - thp_main.set_read_buffer(WIRE_BUFFER) - thp_main.set_write_buffer(WIRE_BUFFER_2) - # Take a mark of modules that are imported at this point, so we can # roll back and un-import any others. modules = utils.unimport_begin() @@ -91,6 +85,8 @@ if utils.USE_THP: return # pylint: disable=lost-exception else: + _PROTOBUF_BUFFER_SIZE = const(8192) + WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) async def handle_session(iface: WireInterface) -> None: ctx = CodecContext(iface, WIRE_BUFFER) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index e40e0f8a3f..52af91bc20 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -56,7 +56,7 @@ class Channel: self.channel_id: bytes = channel_cache.channel_id # Shared variables - self.buffer: utils.BufferType + self.buffer: utils.BufferType = bytearray(64) self.bytes_read: int = 0 self.expected_payload_length: int = 0 self.is_cont_packet_expected: bool = False diff --git a/core/src/trezor/wire/thp/channel_manager.py b/core/src/trezor/wire/thp/channel_manager.py index de0d5fa364..9a564f2470 100644 --- a/core/src/trezor/wire/thp/channel_manager.py +++ b/core/src/trezor/wire/thp/channel_manager.py @@ -1,7 +1,6 @@ from typing import TYPE_CHECKING from storage import cache_thp -from trezor import utils from . import ChannelState, interface_manager from .channel import Channel @@ -10,18 +9,17 @@ if TYPE_CHECKING: from trezorio import WireInterface -def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel: +def create_new_channel(iface: WireInterface) -> Channel: """ Creates a new channel for the interface `iface` with the buffer `buffer`. """ channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface)) channel = Channel(channel_cache) - channel.set_buffer(buffer) channel.set_channel_state(ChannelState.TH1) return channel -def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: +def load_cached_channels() -> dict[int, Channel]: """ Returns all allocated channels from cache. """ @@ -29,6 +27,4 @@ def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: cached_channels = cache_thp.get_all_allocated_channels() for channel in cached_channels: channels[int.from_bytes(channel.channel_id, "big")] = Channel(channel) - for channel in channels.values(): - channel.set_buffer(buffer) return channels diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py index 0a117c16f7..d7fb633134 100644 --- a/core/src/trezor/wire/thp/memory_manager.py +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -11,6 +11,10 @@ from .writer import ( PACKET_LENGTH, ) +_PROTOBUF_BUFFER_SIZE = 8192 +READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) +WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) + def select_buffer( channel_state: int, @@ -115,22 +119,20 @@ def _get_buffer_for_read( if __debug__ and utils.ALLOW_DEBUG_MESSAGES: log.debug(__name__, "Allocating a new buffer") - from .thp_main import get_raw_read_buffer - - if length > len(get_raw_read_buffer()): + if length > len(READ_BUFFER): if __debug__ and utils.ALLOW_DEBUG_MESSAGES: log.debug( __name__, "Required length is %d, where raw buffer has capacity only %d", length, - len(get_raw_read_buffer()), + len(READ_BUFFER), ) raise ThpError("Message is too large") try: - payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length] + payload: utils.BufferType = memoryview(READ_BUFFER)[:length] except MemoryError: - payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH] + payload = memoryview(READ_BUFFER)[:PACKET_LENGTH] raise ThpError("Message is too large") return payload @@ -161,15 +163,13 @@ def _get_buffer_for_write( if __debug__ and utils.ALLOW_DEBUG_MESSAGES: log.debug(__name__, "Creating a new write buffer from raw write buffer") - from .thp_main import get_raw_write_buffer - - if length > len(get_raw_write_buffer()): + if length > len(WRITE_BUFFER): raise ThpError("Message is too large") try: - payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length] + payload: utils.BufferType = memoryview(WRITE_BUFFER)[:length] except MemoryError: - payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH] + payload = memoryview(WRITE_BUFFER)[:PACKET_LENGTH] raise ThpError("Message is too large") return payload diff --git a/core/src/trezor/wire/thp/thp_main.py b/core/src/trezor/wire/thp/thp_main.py index 2381ca0638..a529a22367 100644 --- a/core/src/trezor/wire/thp/thp_main.py +++ b/core/src/trezor/wire/thp/thp_main.py @@ -30,35 +30,12 @@ if TYPE_CHECKING: from trezorio import WireInterface _CID_REQ_PAYLOAD_LENGTH = const(12) -_READ_BUFFER: bytearray -_WRITE_BUFFER: bytearray _CHANNELS: dict[int, Channel] = {} -def set_read_buffer(buffer: bytearray) -> None: - global _READ_BUFFER - _READ_BUFFER = buffer - - -def set_write_buffer(buffer: bytearray) -> None: - global _WRITE_BUFFER - _WRITE_BUFFER = buffer - - -def get_raw_read_buffer() -> bytearray: - global _READ_BUFFER - return _READ_BUFFER - - -def get_raw_write_buffer() -> bytearray: - global _WRITE_BUFFER - return _WRITE_BUFFER - - async def thp_main_loop(iface: WireInterface) -> None: global _CHANNELS - global _READ_BUFFER - _CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER) + _CHANNELS = channel_manager.load_cached_channels() read = loop.wait(iface.iface_num() | io.POLL_READ) @@ -100,7 +77,6 @@ async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None: async def _handle_broadcast( iface: WireInterface, ctrl_byte: int, packet: utils.BufferType ) -> None: - global _READ_BUFFER if ctrl_byte != CHANNEL_ALLOCATION_REQ: raise ThpError("Unexpected ctrl_byte in a broadcast channel packet") if __debug__: @@ -114,7 +90,7 @@ async def _handle_broadcast( ): raise ThpError("Checksum is not valid") - new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER) + new_channel: Channel = channel_manager.create_new_channel(iface) cid = int.from_bytes(new_channel.channel_id, "big") _CHANNELS[cid] = new_channel diff --git a/core/tests/test_trezor.wire.thp.py b/core/tests/test_trezor.wire.thp.py index 576ddab4db..95f6cb6322 100644 --- a/core/tests/test_trezor.wire.thp.py +++ b/core/tests/test_trezor.wire.thp.py @@ -28,7 +28,7 @@ if utils.USE_THP: ThpEndRequest, ThpStartPairingRequest, ) - from trezor.wire.thp import thp_main + from trezor.wire.thp import thp_main, memory_manager from trezor.wire.thp import ChannelState, checksum, interface_manager from trezor.wire.thp.crypto import Handshake from trezor.wire.thp.pairing_context import PairingContext @@ -97,10 +97,8 @@ class TestTrezorHostProtocol(unittest.TestCase): def setUp(self): self.interface = MockHID(0xDEADBEEF) - buffer = bytearray(64) - buffer2 = bytearray(256) - thp_main.set_read_buffer(buffer) - thp_main.set_write_buffer(buffer2) + memory_manager.READ_BUFFER = bytearray(64) + memory_manager.WRITE_BUFFER = bytearray(256) interface_manager.decode_iface = thp_common.dummy_decode_iface def test_codec_message(self):