1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-12 17:38:13 +00:00

refactor(core): move wire buffer handling completely to memory_manager

[no changelog]
This commit is contained in:
M1nd3r 2024-12-04 16:21:06 +01:00
parent d19b58c1ab
commit ece943d3a6
6 changed files with 24 additions and 58 deletions

View File

@ -42,12 +42,6 @@ from .message_handler import failure
# other packages.
from .errors import * # isort:skip # noqa: F401,F403
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if utils.USE_THP:
WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE)
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Any, Callable, Coroutine, TypeVar
@ -65,12 +59,12 @@ def setup(iface: WireInterface) -> None:
if utils.USE_THP:
# memory_manager is imported to create READ/WRITE buffers
# in more stable area of memory
from .thp import memory_manager # noqa: F401
async def handle_session(iface: WireInterface) -> None:
thp_main.set_read_buffer(WIRE_BUFFER)
thp_main.set_write_buffer(WIRE_BUFFER_2)
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
modules = utils.unimport_begin()
@ -91,6 +85,8 @@ if utils.USE_THP:
return # pylint: disable=lost-exception
else:
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
async def handle_session(iface: WireInterface) -> None:
ctx = CodecContext(iface, WIRE_BUFFER)

View File

@ -56,7 +56,7 @@ class Channel:
self.channel_id: bytes = channel_cache.channel_id
# Shared variables
self.buffer: utils.BufferType
self.buffer: utils.BufferType = bytearray(64)
self.bytes_read: int = 0
self.expected_payload_length: int = 0
self.is_cont_packet_expected: bool = False

View File

@ -1,7 +1,6 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import utils
from . import ChannelState, interface_manager
from .channel import Channel
@ -10,18 +9,17 @@ if TYPE_CHECKING:
from trezorio import WireInterface
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel:
def create_new_channel(iface: WireInterface) -> Channel:
"""
Creates a new channel for the interface `iface` with the buffer `buffer`.
"""
channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
channel = Channel(channel_cache)
channel.set_buffer(buffer)
channel.set_channel_state(ChannelState.TH1)
return channel
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
def load_cached_channels() -> dict[int, Channel]:
"""
Returns all allocated channels from cache.
"""
@ -29,6 +27,4 @@ def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
cached_channels = cache_thp.get_all_allocated_channels()
for channel in cached_channels:
channels[int.from_bytes(channel.channel_id, "big")] = Channel(channel)
for channel in channels.values():
channel.set_buffer(buffer)
return channels

View File

@ -11,6 +11,10 @@ from .writer import (
PACKET_LENGTH,
)
_PROTOBUF_BUFFER_SIZE = 8192
READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
def select_buffer(
channel_state: int,
@ -115,22 +119,20 @@ def _get_buffer_for_read(
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Allocating a new buffer")
from .thp_main import get_raw_read_buffer
if length > len(get_raw_read_buffer()):
if length > len(READ_BUFFER):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"Required length is %d, where raw buffer has capacity only %d",
length,
len(get_raw_read_buffer()),
len(READ_BUFFER),
)
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length]
payload: utils.BufferType = memoryview(READ_BUFFER)[:length]
except MemoryError:
payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH]
payload = memoryview(READ_BUFFER)[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload
@ -161,15 +163,13 @@ def _get_buffer_for_write(
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Creating a new write buffer from raw write buffer")
from .thp_main import get_raw_write_buffer
if length > len(get_raw_write_buffer()):
if length > len(WRITE_BUFFER):
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length]
payload: utils.BufferType = memoryview(WRITE_BUFFER)[:length]
except MemoryError:
payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH]
payload = memoryview(WRITE_BUFFER)[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload

View File

@ -30,35 +30,12 @@ if TYPE_CHECKING:
from trezorio import WireInterface
_CID_REQ_PAYLOAD_LENGTH = const(12)
_READ_BUFFER: bytearray
_WRITE_BUFFER: bytearray
_CHANNELS: dict[int, Channel] = {}
def set_read_buffer(buffer: bytearray) -> None:
global _READ_BUFFER
_READ_BUFFER = buffer
def set_write_buffer(buffer: bytearray) -> None:
global _WRITE_BUFFER
_WRITE_BUFFER = buffer
def get_raw_read_buffer() -> bytearray:
global _READ_BUFFER
return _READ_BUFFER
def get_raw_write_buffer() -> bytearray:
global _WRITE_BUFFER
return _WRITE_BUFFER
async def thp_main_loop(iface: WireInterface) -> None:
global _CHANNELS
global _READ_BUFFER
_CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER)
_CHANNELS = channel_manager.load_cached_channels()
read = loop.wait(iface.iface_num() | io.POLL_READ)
@ -100,7 +77,6 @@ async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None:
async def _handle_broadcast(
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
) -> None:
global _READ_BUFFER
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
if __debug__:
@ -114,7 +90,7 @@ async def _handle_broadcast(
):
raise ThpError("Checksum is not valid")
new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER)
new_channel: Channel = channel_manager.create_new_channel(iface)
cid = int.from_bytes(new_channel.channel_id, "big")
_CHANNELS[cid] = new_channel

View File

@ -28,7 +28,7 @@ if utils.USE_THP:
ThpEndRequest,
ThpStartPairingRequest,
)
from trezor.wire.thp import thp_main
from trezor.wire.thp import thp_main, memory_manager
from trezor.wire.thp import ChannelState, checksum, interface_manager
from trezor.wire.thp.crypto import Handshake
from trezor.wire.thp.pairing_context import PairingContext
@ -97,10 +97,8 @@ class TestTrezorHostProtocol(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
buffer = bytearray(64)
buffer2 = bytearray(256)
thp_main.set_read_buffer(buffer)
thp_main.set_write_buffer(buffer2)
memory_manager.READ_BUFFER = bytearray(64)
memory_manager.WRITE_BUFFER = bytearray(256)
interface_manager.decode_iface = thp_common.dummy_decode_iface
def test_codec_message(self):