1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-13 18:08:10 +00:00

refactor(core): move wire buffer handling completely to memory_manager

[no changelog]
This commit is contained in:
M1nd3r 2024-12-04 16:21:06 +01:00
parent d19b58c1ab
commit ece943d3a6
6 changed files with 24 additions and 58 deletions

View File

@ -42,12 +42,6 @@ from .message_handler import failure
# other packages. # other packages.
from .errors import * # isort:skip # noqa: F401,F403 from .errors import * # isort:skip # noqa: F401,F403
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if utils.USE_THP:
WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE)
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
from typing import Any, Callable, Coroutine, TypeVar from typing import Any, Callable, Coroutine, TypeVar
@ -65,12 +59,12 @@ def setup(iface: WireInterface) -> None:
if utils.USE_THP: if utils.USE_THP:
# memory_manager is imported to create READ/WRITE buffers
# in more stable area of memory
from .thp import memory_manager # noqa: F401
async def handle_session(iface: WireInterface) -> None: async def handle_session(iface: WireInterface) -> None:
thp_main.set_read_buffer(WIRE_BUFFER)
thp_main.set_write_buffer(WIRE_BUFFER_2)
# Take a mark of modules that are imported at this point, so we can # Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others. # roll back and un-import any others.
modules = utils.unimport_begin() modules = utils.unimport_begin()
@ -91,6 +85,8 @@ if utils.USE_THP:
return # pylint: disable=lost-exception return # pylint: disable=lost-exception
else: else:
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
async def handle_session(iface: WireInterface) -> None: async def handle_session(iface: WireInterface) -> None:
ctx = CodecContext(iface, WIRE_BUFFER) ctx = CodecContext(iface, WIRE_BUFFER)

View File

@ -56,7 +56,7 @@ class Channel:
self.channel_id: bytes = channel_cache.channel_id self.channel_id: bytes = channel_cache.channel_id
# Shared variables # Shared variables
self.buffer: utils.BufferType self.buffer: utils.BufferType = bytearray(64)
self.bytes_read: int = 0 self.bytes_read: int = 0
self.expected_payload_length: int = 0 self.expected_payload_length: int = 0
self.is_cont_packet_expected: bool = False self.is_cont_packet_expected: bool = False

View File

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage import cache_thp from storage import cache_thp
from trezor import utils
from . import ChannelState, interface_manager from . import ChannelState, interface_manager
from .channel import Channel from .channel import Channel
@ -10,18 +9,17 @@ if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel: def create_new_channel(iface: WireInterface) -> Channel:
""" """
Creates a new channel for the interface `iface` with the buffer `buffer`. Creates a new channel for the interface `iface` with the buffer `buffer`.
""" """
channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface)) channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
channel = Channel(channel_cache) channel = Channel(channel_cache)
channel.set_buffer(buffer)
channel.set_channel_state(ChannelState.TH1) channel.set_channel_state(ChannelState.TH1)
return channel return channel
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: def load_cached_channels() -> dict[int, Channel]:
""" """
Returns all allocated channels from cache. Returns all allocated channels from cache.
""" """
@ -29,6 +27,4 @@ def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
cached_channels = cache_thp.get_all_allocated_channels() cached_channels = cache_thp.get_all_allocated_channels()
for channel in cached_channels: for channel in cached_channels:
channels[int.from_bytes(channel.channel_id, "big")] = Channel(channel) channels[int.from_bytes(channel.channel_id, "big")] = Channel(channel)
for channel in channels.values():
channel.set_buffer(buffer)
return channels return channels

View File

@ -11,6 +11,10 @@ from .writer import (
PACKET_LENGTH, PACKET_LENGTH,
) )
_PROTOBUF_BUFFER_SIZE = 8192
READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
def select_buffer( def select_buffer(
channel_state: int, channel_state: int,
@ -115,22 +119,20 @@ def _get_buffer_for_read(
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Allocating a new buffer") log.debug(__name__, "Allocating a new buffer")
from .thp_main import get_raw_read_buffer if length > len(READ_BUFFER):
if length > len(get_raw_read_buffer()):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug( log.debug(
__name__, __name__,
"Required length is %d, where raw buffer has capacity only %d", "Required length is %d, where raw buffer has capacity only %d",
length, length,
len(get_raw_read_buffer()), len(READ_BUFFER),
) )
raise ThpError("Message is too large") raise ThpError("Message is too large")
try: try:
payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length] payload: utils.BufferType = memoryview(READ_BUFFER)[:length]
except MemoryError: except MemoryError:
payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH] payload = memoryview(READ_BUFFER)[:PACKET_LENGTH]
raise ThpError("Message is too large") raise ThpError("Message is too large")
return payload return payload
@ -161,15 +163,13 @@ def _get_buffer_for_write(
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Creating a new write buffer from raw write buffer") log.debug(__name__, "Creating a new write buffer from raw write buffer")
from .thp_main import get_raw_write_buffer if length > len(WRITE_BUFFER):
if length > len(get_raw_write_buffer()):
raise ThpError("Message is too large") raise ThpError("Message is too large")
try: try:
payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length] payload: utils.BufferType = memoryview(WRITE_BUFFER)[:length]
except MemoryError: except MemoryError:
payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH] payload = memoryview(WRITE_BUFFER)[:PACKET_LENGTH]
raise ThpError("Message is too large") raise ThpError("Message is too large")
return payload return payload

View File

@ -30,35 +30,12 @@ if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
_CID_REQ_PAYLOAD_LENGTH = const(12) _CID_REQ_PAYLOAD_LENGTH = const(12)
_READ_BUFFER: bytearray
_WRITE_BUFFER: bytearray
_CHANNELS: dict[int, Channel] = {} _CHANNELS: dict[int, Channel] = {}
def set_read_buffer(buffer: bytearray) -> None:
global _READ_BUFFER
_READ_BUFFER = buffer
def set_write_buffer(buffer: bytearray) -> None:
global _WRITE_BUFFER
_WRITE_BUFFER = buffer
def get_raw_read_buffer() -> bytearray:
global _READ_BUFFER
return _READ_BUFFER
def get_raw_write_buffer() -> bytearray:
global _WRITE_BUFFER
return _WRITE_BUFFER
async def thp_main_loop(iface: WireInterface) -> None: async def thp_main_loop(iface: WireInterface) -> None:
global _CHANNELS global _CHANNELS
global _READ_BUFFER _CHANNELS = channel_manager.load_cached_channels()
_CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER)
read = loop.wait(iface.iface_num() | io.POLL_READ) read = loop.wait(iface.iface_num() | io.POLL_READ)
@ -100,7 +77,6 @@ async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None:
async def _handle_broadcast( async def _handle_broadcast(
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
) -> None: ) -> None:
global _READ_BUFFER
if ctrl_byte != CHANNEL_ALLOCATION_REQ: if ctrl_byte != CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet") raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
if __debug__: if __debug__:
@ -114,7 +90,7 @@ async def _handle_broadcast(
): ):
raise ThpError("Checksum is not valid") raise ThpError("Checksum is not valid")
new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER) new_channel: Channel = channel_manager.create_new_channel(iface)
cid = int.from_bytes(new_channel.channel_id, "big") cid = int.from_bytes(new_channel.channel_id, "big")
_CHANNELS[cid] = new_channel _CHANNELS[cid] = new_channel

View File

@ -28,7 +28,7 @@ if utils.USE_THP:
ThpEndRequest, ThpEndRequest,
ThpStartPairingRequest, ThpStartPairingRequest,
) )
from trezor.wire.thp import thp_main from trezor.wire.thp import thp_main, memory_manager
from trezor.wire.thp import ChannelState, checksum, interface_manager from trezor.wire.thp import ChannelState, checksum, interface_manager
from trezor.wire.thp.crypto import Handshake from trezor.wire.thp.crypto import Handshake
from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.pairing_context import PairingContext
@ -97,10 +97,8 @@ class TestTrezorHostProtocol(unittest.TestCase):
def setUp(self): def setUp(self):
self.interface = MockHID(0xDEADBEEF) self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64) memory_manager.READ_BUFFER = bytearray(64)
buffer2 = bytearray(256) memory_manager.WRITE_BUFFER = bytearray(256)
thp_main.set_read_buffer(buffer)
thp_main.set_write_buffer(buffer2)
interface_manager.decode_iface = thp_common.dummy_decode_iface interface_manager.decode_iface = thp_common.dummy_decode_iface
def test_codec_message(self): def test_codec_message(self):