diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 21ffc9f1bd..0da7dc1532 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -357,8 +357,6 @@ if __debug__: async def _no_op(_msg: Any) -> Success: return Success() - WIRE_BUFFER_DEBUG = bytearray(1024) - async def handle_session(iface: WireInterface) -> None: from trezor import protobuf, wire from trezor.wire.codec import codec_v1 @@ -366,7 +364,7 @@ if __debug__: global DEBUG_CONTEXT - DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG) + DEBUG_CONTEXT = ctx = CodecContext(iface, wire.BufferProvider(1024), "Debug") if storage.layout_watcher: try: diff --git a/core/src/session.py b/core/src/session.py index d1bb0628ab..b261d07a0a 100644 --- a/core/src/session.py +++ b/core/src/session.py @@ -6,12 +6,7 @@ from trezor import log, loop, utils, wire, workflow import apps.base import usb -_PROTOBUF_BUFFER_SIZE = const(8192) -USB_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) - -if utils.USE_BLE: - BLE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) - +buffer_provider = wire.BufferProvider(8192) apps.base.boot() @@ -30,13 +25,13 @@ apps.base.set_homescreen() workflow.start_default() # initialize the wire codec over USB -wire.setup(usb.iface_wire, USB_BUFFER) +wire.setup(usb.iface_wire, buffer_provider, "USB") if utils.USE_BLE: import bluetooth # initialize the wire codec over BLE - wire.setup(bluetooth.iface_ble, BLE_BUFFER) + wire.setup(bluetooth.iface_ble, buffer_provider, "BLE") # start the event loop loop.run() diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index ac4e98fd67..94e39ff8b9 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -47,13 +47,28 @@ if TYPE_CHECKING: LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) -def setup(iface: WireInterface, buffer: bytearray) -> None: +class BufferProvider: + def __init__(self, size): + self.buf = bytearray(size) + self.owner = None + + def take(self, owner: str) -> bytearray | None: + if self.buf is None: + return None + + buf = self.buf + self.buf = None + self.owner = owner + return buf + + +def setup(iface: WireInterface, buffer_provider: BufferProvider, name: str) -> None: """Initialize the wire stack on the provided WireInterface.""" - loop.schedule(handle_session(iface, buffer)) + loop.schedule(handle_session(iface, buffer_provider, name)) -async def handle_session(iface: WireInterface, buffer: bytearray) -> None: - ctx = CodecContext(iface, buffer) +async def handle_session(iface: WireInterface, buffer_provider: BufferProvider, name: str) -> None: + ctx = CodecContext(iface, buffer_provider, name) next_msg: protocol_common.Message | None = None # Take a mark of modules that are imported at this point, so we can diff --git a/core/src/trezor/wire/codec/codec_context.py b/core/src/trezor/wire/codec/codec_context.py index 2d5a7b7c9a..54086bc099 100644 --- a/core/src/trezor/wire/codec/codec_context.py +++ b/core/src/trezor/wire/codec/codec_context.py @@ -5,12 +5,12 @@ from storage.cache_common import DataCache, InvalidSessionError from trezor import log, protobuf from trezor.wire.codec import codec_v1 from trezor.wire.context import UnexpectedMessageException -from trezor.wire.protocol_common import Context, Message +from trezor.wire.protocol_common import Context, Message, WireError if TYPE_CHECKING: from typing import TypeVar - from trezor.wire import WireInterface + from trezor.wire import WireInterface, BufferProvider LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) @@ -21,14 +21,29 @@ class CodecContext(Context): def __init__( self, iface: WireInterface, - buffer: bytearray, + buffer_provider: BufferProvider, + name: str, ) -> None: - self.buffer = buffer + self.buffer_provider = buffer_provider + self._buffer = None + self.name = name super().__init__(iface) + def get_buffer(self) -> bytearray: + if self._buffer is not None: + return self._buffer + + self._buffer = self.buffer_provider.take(self.name) + if self._buffer is not None: + return self._buffer + + # The exception should be caught by and handled by `wire.handle_session()` task. + # It doesn't terminate the "blocked" session (to allow sending error responses). + raise WireError(f"{self.buffer_provider.owner} session in progress, {self.name} is blocked") + def read_from_wire(self) -> Awaitable[Message]: """Read a whole message from the wire without parsing it.""" - return codec_v1.read_message(self.iface, self.buffer) + return codec_v1.read_message(self.iface, self.get_buffer) async def read( self, @@ -81,10 +96,15 @@ class CodecContext(Context): msg_size = protobuf.encoded_length(msg) - if msg_size <= len(self.buffer): - # reuse preallocated - buffer = self.buffer + if self._buffer is not None: + buffer = self._buffer else: + # Allow sending small responses (for error reporting when another session is in progress + if msg_size > 128: + raise MemoryError(msg_size) ### FIXME + buffer = bytearray() + + if msg_size > len(buffer): # message is too big, we need to allocate a new buffer buffer = bytearray(msg_size) diff --git a/core/src/trezor/wire/codec/codec_v1.py b/core/src/trezor/wire/codec/codec_v1.py index c1c3b39e7c..bb4146e21e 100644 --- a/core/src/trezor/wire/codec/codec_v1.py +++ b/core/src/trezor/wire/codec/codec_v1.py @@ -7,6 +7,7 @@ from trezor.wire.protocol_common import Message, WireError if TYPE_CHECKING: from trezorio import WireInterface + from typing import Callable _REP_MARKER = const(63) # ord('?') _REP_MAGIC = const(35) # org('#') @@ -19,7 +20,7 @@ class CodecError(WireError): pass -async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: +async def read_message(iface: WireInterface, buffer_getter: Callable[[], bytearray]) -> Message: read = loop.wait(iface.iface_num() | io.POLL_READ) report = bytearray(iface.RX_PACKET_LEN) @@ -33,6 +34,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC: raise CodecError("Invalid magic") + buffer = buffer_getter() # will throw if other session is in progress read_and_throw_away = False if msize > len(buffer):