mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-03 12:00:59 +00:00
refactor(core): move wire buffer handling completely to memory_manager
[no changelog]
This commit is contained in:
parent
4d9699a883
commit
8aef867259
@ -42,12 +42,6 @@ from .message_handler import failure
|
||||
# other packages.
|
||||
from .errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
if utils.USE_THP:
|
||||
WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Any, Callable, Coroutine, TypeVar
|
||||
@ -65,12 +59,12 @@ def setup(iface: WireInterface) -> None:
|
||||
|
||||
|
||||
if utils.USE_THP:
|
||||
# memory_manager is imported to create READ/WRITE buffers
|
||||
# in more stable area of memory
|
||||
from .thp import memory_manager # noqa: F401
|
||||
|
||||
async def handle_session(iface: WireInterface) -> None:
|
||||
|
||||
thp_main.set_read_buffer(WIRE_BUFFER)
|
||||
thp_main.set_write_buffer(WIRE_BUFFER_2)
|
||||
|
||||
# Take a mark of modules that are imported at this point, so we can
|
||||
# roll back and un-import any others.
|
||||
modules = utils.unimport_begin()
|
||||
@ -91,6 +85,8 @@ if utils.USE_THP:
|
||||
return # pylint: disable=lost-exception
|
||||
|
||||
else:
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
async def handle_session(iface: WireInterface) -> None:
|
||||
ctx = CodecContext(iface, WIRE_BUFFER)
|
||||
|
@ -56,7 +56,7 @@ class Channel:
|
||||
self.channel_id: bytes = channel_cache.channel_id
|
||||
|
||||
# Shared variables
|
||||
self.buffer: utils.BufferType
|
||||
self.buffer: utils.BufferType = bytearray(64)
|
||||
self.bytes_read: int = 0
|
||||
self.expected_payload_length: int = 0
|
||||
self.is_cont_packet_expected: bool = False
|
||||
|
@ -1,7 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache_thp
|
||||
from trezor import utils
|
||||
|
||||
from . import ChannelState, interface_manager
|
||||
from .channel import Channel
|
||||
@ -10,18 +9,17 @@ if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
|
||||
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel:
|
||||
def create_new_channel(iface: WireInterface) -> Channel:
|
||||
"""
|
||||
Creates a new channel for the interface `iface` with the buffer `buffer`.
|
||||
"""
|
||||
channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
|
||||
channel = Channel(channel_cache)
|
||||
channel.set_buffer(buffer)
|
||||
channel.set_channel_state(ChannelState.TH1)
|
||||
return channel
|
||||
|
||||
|
||||
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
|
||||
def load_cached_channels() -> dict[int, Channel]:
|
||||
"""
|
||||
Returns all allocated channels from cache.
|
||||
"""
|
||||
@ -29,6 +27,4 @@ def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
|
||||
cached_channels = cache_thp.get_all_allocated_channels()
|
||||
for channel in cached_channels:
|
||||
channels[int.from_bytes(channel.channel_id, "big")] = Channel(channel)
|
||||
for channel in channels.values():
|
||||
channel.set_buffer(buffer)
|
||||
return channels
|
||||
|
@ -11,6 +11,10 @@ from .writer import (
|
||||
PACKET_LENGTH,
|
||||
)
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = 8192
|
||||
READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
|
||||
def select_buffer(
|
||||
channel_state: int,
|
||||
@ -115,22 +119,20 @@ def _get_buffer_for_read(
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Allocating a new buffer")
|
||||
|
||||
from .thp_main import get_raw_read_buffer
|
||||
|
||||
if length > len(get_raw_read_buffer()):
|
||||
if length > len(READ_BUFFER):
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Required length is %d, where raw buffer has capacity only %d",
|
||||
length,
|
||||
len(get_raw_read_buffer()),
|
||||
len(READ_BUFFER),
|
||||
)
|
||||
raise ThpError("Message is too large")
|
||||
|
||||
try:
|
||||
payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length]
|
||||
payload: utils.BufferType = memoryview(READ_BUFFER)[:length]
|
||||
except MemoryError:
|
||||
payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH]
|
||||
payload = memoryview(READ_BUFFER)[:PACKET_LENGTH]
|
||||
raise ThpError("Message is too large")
|
||||
return payload
|
||||
|
||||
@ -161,15 +163,13 @@ def _get_buffer_for_write(
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "Creating a new write buffer from raw write buffer")
|
||||
|
||||
from .thp_main import get_raw_write_buffer
|
||||
|
||||
if length > len(get_raw_write_buffer()):
|
||||
if length > len(WRITE_BUFFER):
|
||||
raise ThpError("Message is too large")
|
||||
|
||||
try:
|
||||
payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length]
|
||||
payload: utils.BufferType = memoryview(WRITE_BUFFER)[:length]
|
||||
except MemoryError:
|
||||
payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH]
|
||||
payload = memoryview(WRITE_BUFFER)[:PACKET_LENGTH]
|
||||
raise ThpError("Message is too large")
|
||||
return payload
|
||||
|
||||
|
@ -30,35 +30,12 @@ if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
_CID_REQ_PAYLOAD_LENGTH = const(12)
|
||||
_READ_BUFFER: bytearray
|
||||
_WRITE_BUFFER: bytearray
|
||||
_CHANNELS: dict[int, Channel] = {}
|
||||
|
||||
|
||||
def set_read_buffer(buffer: bytearray) -> None:
|
||||
global _READ_BUFFER
|
||||
_READ_BUFFER = buffer
|
||||
|
||||
|
||||
def set_write_buffer(buffer: bytearray) -> None:
|
||||
global _WRITE_BUFFER
|
||||
_WRITE_BUFFER = buffer
|
||||
|
||||
|
||||
def get_raw_read_buffer() -> bytearray:
|
||||
global _READ_BUFFER
|
||||
return _READ_BUFFER
|
||||
|
||||
|
||||
def get_raw_write_buffer() -> bytearray:
|
||||
global _WRITE_BUFFER
|
||||
return _WRITE_BUFFER
|
||||
|
||||
|
||||
async def thp_main_loop(iface: WireInterface) -> None:
|
||||
global _CHANNELS
|
||||
global _READ_BUFFER
|
||||
_CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER)
|
||||
_CHANNELS = channel_manager.load_cached_channels()
|
||||
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
|
||||
@ -100,7 +77,6 @@ async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None:
|
||||
async def _handle_broadcast(
|
||||
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
|
||||
) -> None:
|
||||
global _READ_BUFFER
|
||||
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
|
||||
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
|
||||
if __debug__:
|
||||
@ -114,7 +90,7 @@ async def _handle_broadcast(
|
||||
):
|
||||
raise ThpError("Checksum is not valid")
|
||||
|
||||
new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER)
|
||||
new_channel: Channel = channel_manager.create_new_channel(iface)
|
||||
cid = int.from_bytes(new_channel.channel_id, "big")
|
||||
_CHANNELS[cid] = new_channel
|
||||
|
||||
|
@ -28,7 +28,7 @@ if utils.USE_THP:
|
||||
ThpEndRequest,
|
||||
ThpStartPairingRequest,
|
||||
)
|
||||
from trezor.wire.thp import thp_main
|
||||
from trezor.wire.thp import thp_main, memory_manager
|
||||
from trezor.wire.thp import ChannelState, checksum, interface_manager
|
||||
from trezor.wire.thp.crypto import Handshake
|
||||
from trezor.wire.thp.pairing_context import PairingContext
|
||||
@ -97,10 +97,8 @@ class TestTrezorHostProtocol(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
buffer = bytearray(64)
|
||||
buffer2 = bytearray(256)
|
||||
thp_main.set_read_buffer(buffer)
|
||||
thp_main.set_write_buffer(buffer2)
|
||||
memory_manager.READ_BUFFER = bytearray(64)
|
||||
memory_manager.WRITE_BUFFER = bytearray(256)
|
||||
interface_manager.decode_iface = thp_common.dummy_decode_iface
|
||||
|
||||
def test_codec_message(self):
|
||||
|
Loading…
Reference in New Issue
Block a user