mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-03 12:00:59 +00:00
refactor(core): move wire buffer handling completely to memory_manager
[no changelog]
This commit is contained in:
parent
4d9699a883
commit
8aef867259
@ -42,12 +42,6 @@ from .message_handler import failure
|
|||||||
# other packages.
|
# other packages.
|
||||||
from .errors import * # isort:skip # noqa: F401,F403
|
from .errors import * # isort:skip # noqa: F401,F403
|
||||||
|
|
||||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
|
||||||
|
|
||||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
|
||||||
if utils.USE_THP:
|
|
||||||
WIRE_BUFFER_2 = bytearray(_PROTOBUF_BUFFER_SIZE)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
from typing import Any, Callable, Coroutine, TypeVar
|
from typing import Any, Callable, Coroutine, TypeVar
|
||||||
@ -65,12 +59,12 @@ def setup(iface: WireInterface) -> None:
|
|||||||
|
|
||||||
|
|
||||||
if utils.USE_THP:
|
if utils.USE_THP:
|
||||||
|
# memory_manager is imported to create READ/WRITE buffers
|
||||||
|
# in more stable area of memory
|
||||||
|
from .thp import memory_manager # noqa: F401
|
||||||
|
|
||||||
async def handle_session(iface: WireInterface) -> None:
|
async def handle_session(iface: WireInterface) -> None:
|
||||||
|
|
||||||
thp_main.set_read_buffer(WIRE_BUFFER)
|
|
||||||
thp_main.set_write_buffer(WIRE_BUFFER_2)
|
|
||||||
|
|
||||||
# Take a mark of modules that are imported at this point, so we can
|
# Take a mark of modules that are imported at this point, so we can
|
||||||
# roll back and un-import any others.
|
# roll back and un-import any others.
|
||||||
modules = utils.unimport_begin()
|
modules = utils.unimport_begin()
|
||||||
@ -91,6 +85,8 @@ if utils.USE_THP:
|
|||||||
return # pylint: disable=lost-exception
|
return # pylint: disable=lost-exception
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||||
|
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||||
|
|
||||||
async def handle_session(iface: WireInterface) -> None:
|
async def handle_session(iface: WireInterface) -> None:
|
||||||
ctx = CodecContext(iface, WIRE_BUFFER)
|
ctx = CodecContext(iface, WIRE_BUFFER)
|
||||||
|
@ -56,7 +56,7 @@ class Channel:
|
|||||||
self.channel_id: bytes = channel_cache.channel_id
|
self.channel_id: bytes = channel_cache.channel_id
|
||||||
|
|
||||||
# Shared variables
|
# Shared variables
|
||||||
self.buffer: utils.BufferType
|
self.buffer: utils.BufferType = bytearray(64)
|
||||||
self.bytes_read: int = 0
|
self.bytes_read: int = 0
|
||||||
self.expected_payload_length: int = 0
|
self.expected_payload_length: int = 0
|
||||||
self.is_cont_packet_expected: bool = False
|
self.is_cont_packet_expected: bool = False
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from storage import cache_thp
|
from storage import cache_thp
|
||||||
from trezor import utils
|
|
||||||
|
|
||||||
from . import ChannelState, interface_manager
|
from . import ChannelState, interface_manager
|
||||||
from .channel import Channel
|
from .channel import Channel
|
||||||
@ -10,18 +9,17 @@ if TYPE_CHECKING:
|
|||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
|
|
||||||
|
|
||||||
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> Channel:
|
def create_new_channel(iface: WireInterface) -> Channel:
|
||||||
"""
|
"""
|
||||||
Creates a new channel for the interface `iface` with the buffer `buffer`.
|
Creates a new channel for the interface `iface` with the buffer `buffer`.
|
||||||
"""
|
"""
|
||||||
channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
|
channel_cache = cache_thp.get_new_channel(interface_manager.encode_iface(iface))
|
||||||
channel = Channel(channel_cache)
|
channel = Channel(channel_cache)
|
||||||
channel.set_buffer(buffer)
|
|
||||||
channel.set_channel_state(ChannelState.TH1)
|
channel.set_channel_state(ChannelState.TH1)
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
|
|
||||||
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
|
def load_cached_channels() -> dict[int, Channel]:
|
||||||
"""
|
"""
|
||||||
Returns all allocated channels from cache.
|
Returns all allocated channels from cache.
|
||||||
"""
|
"""
|
||||||
@ -29,6 +27,4 @@ def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]:
|
|||||||
cached_channels = cache_thp.get_all_allocated_channels()
|
cached_channels = cache_thp.get_all_allocated_channels()
|
||||||
for channel in cached_channels:
|
for channel in cached_channels:
|
||||||
channels[int.from_bytes(channel.channel_id, "big")] = Channel(channel)
|
channels[int.from_bytes(channel.channel_id, "big")] = Channel(channel)
|
||||||
for channel in channels.values():
|
|
||||||
channel.set_buffer(buffer)
|
|
||||||
return channels
|
return channels
|
||||||
|
@ -11,6 +11,10 @@ from .writer import (
|
|||||||
PACKET_LENGTH,
|
PACKET_LENGTH,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_PROTOBUF_BUFFER_SIZE = 8192
|
||||||
|
READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||||
|
WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||||
|
|
||||||
|
|
||||||
def select_buffer(
|
def select_buffer(
|
||||||
channel_state: int,
|
channel_state: int,
|
||||||
@ -115,22 +119,20 @@ def _get_buffer_for_read(
|
|||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Allocating a new buffer")
|
log.debug(__name__, "Allocating a new buffer")
|
||||||
|
|
||||||
from .thp_main import get_raw_read_buffer
|
if length > len(READ_BUFFER):
|
||||||
|
|
||||||
if length > len(get_raw_read_buffer()):
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"Required length is %d, where raw buffer has capacity only %d",
|
"Required length is %d, where raw buffer has capacity only %d",
|
||||||
length,
|
length,
|
||||||
len(get_raw_read_buffer()),
|
len(READ_BUFFER),
|
||||||
)
|
)
|
||||||
raise ThpError("Message is too large")
|
raise ThpError("Message is too large")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload: utils.BufferType = memoryview(get_raw_read_buffer())[:length]
|
payload: utils.BufferType = memoryview(READ_BUFFER)[:length]
|
||||||
except MemoryError:
|
except MemoryError:
|
||||||
payload = memoryview(get_raw_read_buffer())[:PACKET_LENGTH]
|
payload = memoryview(READ_BUFFER)[:PACKET_LENGTH]
|
||||||
raise ThpError("Message is too large")
|
raise ThpError("Message is too large")
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
@ -161,15 +163,13 @@ def _get_buffer_for_write(
|
|||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, "Creating a new write buffer from raw write buffer")
|
log.debug(__name__, "Creating a new write buffer from raw write buffer")
|
||||||
|
|
||||||
from .thp_main import get_raw_write_buffer
|
if length > len(WRITE_BUFFER):
|
||||||
|
|
||||||
if length > len(get_raw_write_buffer()):
|
|
||||||
raise ThpError("Message is too large")
|
raise ThpError("Message is too large")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload: utils.BufferType = memoryview(get_raw_write_buffer())[:length]
|
payload: utils.BufferType = memoryview(WRITE_BUFFER)[:length]
|
||||||
except MemoryError:
|
except MemoryError:
|
||||||
payload = memoryview(get_raw_write_buffer())[:PACKET_LENGTH]
|
payload = memoryview(WRITE_BUFFER)[:PACKET_LENGTH]
|
||||||
raise ThpError("Message is too large")
|
raise ThpError("Message is too large")
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
@ -30,35 +30,12 @@ if TYPE_CHECKING:
|
|||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
|
|
||||||
_CID_REQ_PAYLOAD_LENGTH = const(12)
|
_CID_REQ_PAYLOAD_LENGTH = const(12)
|
||||||
_READ_BUFFER: bytearray
|
|
||||||
_WRITE_BUFFER: bytearray
|
|
||||||
_CHANNELS: dict[int, Channel] = {}
|
_CHANNELS: dict[int, Channel] = {}
|
||||||
|
|
||||||
|
|
||||||
def set_read_buffer(buffer: bytearray) -> None:
|
|
||||||
global _READ_BUFFER
|
|
||||||
_READ_BUFFER = buffer
|
|
||||||
|
|
||||||
|
|
||||||
def set_write_buffer(buffer: bytearray) -> None:
|
|
||||||
global _WRITE_BUFFER
|
|
||||||
_WRITE_BUFFER = buffer
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_read_buffer() -> bytearray:
|
|
||||||
global _READ_BUFFER
|
|
||||||
return _READ_BUFFER
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_write_buffer() -> bytearray:
|
|
||||||
global _WRITE_BUFFER
|
|
||||||
return _WRITE_BUFFER
|
|
||||||
|
|
||||||
|
|
||||||
async def thp_main_loop(iface: WireInterface) -> None:
|
async def thp_main_loop(iface: WireInterface) -> None:
|
||||||
global _CHANNELS
|
global _CHANNELS
|
||||||
global _READ_BUFFER
|
_CHANNELS = channel_manager.load_cached_channels()
|
||||||
_CHANNELS = channel_manager.load_cached_channels(_READ_BUFFER)
|
|
||||||
|
|
||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
||||||
@ -100,7 +77,6 @@ async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None:
|
|||||||
async def _handle_broadcast(
|
async def _handle_broadcast(
|
||||||
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
|
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
|
||||||
) -> None:
|
) -> None:
|
||||||
global _READ_BUFFER
|
|
||||||
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
|
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
|
||||||
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
|
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
|
||||||
if __debug__:
|
if __debug__:
|
||||||
@ -114,7 +90,7 @@ async def _handle_broadcast(
|
|||||||
):
|
):
|
||||||
raise ThpError("Checksum is not valid")
|
raise ThpError("Checksum is not valid")
|
||||||
|
|
||||||
new_channel: Channel = channel_manager.create_new_channel(iface, _READ_BUFFER)
|
new_channel: Channel = channel_manager.create_new_channel(iface)
|
||||||
cid = int.from_bytes(new_channel.channel_id, "big")
|
cid = int.from_bytes(new_channel.channel_id, "big")
|
||||||
_CHANNELS[cid] = new_channel
|
_CHANNELS[cid] = new_channel
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ if utils.USE_THP:
|
|||||||
ThpEndRequest,
|
ThpEndRequest,
|
||||||
ThpStartPairingRequest,
|
ThpStartPairingRequest,
|
||||||
)
|
)
|
||||||
from trezor.wire.thp import thp_main
|
from trezor.wire.thp import thp_main, memory_manager
|
||||||
from trezor.wire.thp import ChannelState, checksum, interface_manager
|
from trezor.wire.thp import ChannelState, checksum, interface_manager
|
||||||
from trezor.wire.thp.crypto import Handshake
|
from trezor.wire.thp.crypto import Handshake
|
||||||
from trezor.wire.thp.pairing_context import PairingContext
|
from trezor.wire.thp.pairing_context import PairingContext
|
||||||
@ -97,10 +97,8 @@ class TestTrezorHostProtocol(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.interface = MockHID(0xDEADBEEF)
|
self.interface = MockHID(0xDEADBEEF)
|
||||||
buffer = bytearray(64)
|
memory_manager.READ_BUFFER = bytearray(64)
|
||||||
buffer2 = bytearray(256)
|
memory_manager.WRITE_BUFFER = bytearray(256)
|
||||||
thp_main.set_read_buffer(buffer)
|
|
||||||
thp_main.set_write_buffer(buffer2)
|
|
||||||
interface_manager.decode_iface = thp_common.dummy_decode_iface
|
interface_manager.decode_iface = thp_common.dummy_decode_iface
|
||||||
|
|
||||||
def test_codec_message(self):
|
def test_codec_message(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user