1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-08 03:16:13 +00:00

fixup! feat(core): add BleInterface to session handling

This commit is contained in:
Roman Zeyde 2025-03-20 19:45:54 +02:00
parent 92bcbd97bc
commit bf0b939c6f
5 changed files with 54 additions and 24 deletions

View File

@ -357,8 +357,6 @@ if __debug__:
async def _no_op(_msg: Any) -> Success:
return Success()
WIRE_BUFFER_DEBUG = bytearray(1024)
async def handle_session(iface: WireInterface) -> None:
from trezor import protobuf, wire
from trezor.wire.codec import codec_v1
@ -366,7 +364,7 @@ if __debug__:
global DEBUG_CONTEXT
DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG)
DEBUG_CONTEXT = ctx = CodecContext(iface, wire.BufferProvider(1024), "Debug")
if storage.layout_watcher:
try:

View File

@ -6,12 +6,7 @@ from trezor import log, loop, utils, wire, workflow
import apps.base
import usb
_PROTOBUF_BUFFER_SIZE = const(8192)
USB_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
if utils.USE_BLE:
BLE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
buffer_provider = wire.BufferProvider(8192)
apps.base.boot()
@ -30,13 +25,13 @@ apps.base.set_homescreen()
workflow.start_default()
# initialize the wire codec over USB
wire.setup(usb.iface_wire, USB_BUFFER)
wire.setup(usb.iface_wire, buffer_provider, "USB")
if utils.USE_BLE:
import bluetooth
# initialize the wire codec over BLE
wire.setup(bluetooth.iface_ble, BLE_BUFFER)
wire.setup(bluetooth.iface_ble, buffer_provider, "BLE")
# start the event loop
loop.run()

View File

@ -47,13 +47,28 @@ if TYPE_CHECKING:
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
def setup(iface: WireInterface, buffer: bytearray) -> None:
class BufferProvider:
def __init__(self, size):
self.buf = bytearray(size)
self.owner = None
def take(self, owner: str) -> bytearray | None:
if self.buf is None:
return None
buf = self.buf
self.buf = None
self.owner = owner
return buf
def setup(iface: WireInterface, buffer_provider: BufferProvider, name: str) -> None:
"""Initialize the wire stack on the provided WireInterface."""
loop.schedule(handle_session(iface, buffer))
loop.schedule(handle_session(iface, buffer_provider, name))
async def handle_session(iface: WireInterface, buffer: bytearray) -> None:
ctx = CodecContext(iface, buffer)
async def handle_session(iface: WireInterface, buffer_provider: BufferProvider, name: str) -> None:
ctx = CodecContext(iface, buffer_provider, name)
next_msg: protocol_common.Message | None = None
# Take a mark of modules that are imported at this point, so we can

View File

@ -5,12 +5,12 @@ from storage.cache_common import DataCache, InvalidSessionError
from trezor import log, protobuf
from trezor.wire.codec import codec_v1
from trezor.wire.context import UnexpectedMessageException
from trezor.wire.protocol_common import Context, Message
from trezor.wire.protocol_common import Context, Message, WireError
if TYPE_CHECKING:
from typing import TypeVar
from trezor.wire import WireInterface
from trezor.wire import WireInterface, BufferProvider
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
@ -21,14 +21,29 @@ class CodecContext(Context):
def __init__(
self,
iface: WireInterface,
buffer: bytearray,
buffer_provider: BufferProvider,
name: str,
) -> None:
self.buffer = buffer
self.buffer_provider = buffer_provider
self._buffer = None
self.name = name
super().__init__(iface)
def get_buffer(self) -> bytearray:
if self._buffer is not None:
return self._buffer
self._buffer = self.buffer_provider.take(self.name)
if self._buffer is not None:
return self._buffer
# The exception should be caught by and handled by `wire.handle_session()` task.
# It doesn't terminate the "blocked" session (to allow sending error responses).
raise WireError(f"{self.buffer_provider.owner} session in progress, {self.name} is blocked")
def read_from_wire(self) -> Awaitable[Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
return codec_v1.read_message(self.iface, self.get_buffer)
async def read(
self,
@ -81,10 +96,15 @@ class CodecContext(Context):
msg_size = protobuf.encoded_length(msg)
if msg_size <= len(self.buffer):
# reuse preallocated
buffer = self.buffer
if self._buffer is not None:
buffer = self._buffer
else:
# Allow sending small responses (for error reporting when another session is in progress
if msg_size > 128:
raise MemoryError(msg_size) ### FIXME
buffer = bytearray()
if msg_size > len(buffer):
# message is too big, we need to allocate a new buffer
buffer = bytearray(msg_size)

View File

@ -7,6 +7,7 @@ from trezor.wire.protocol_common import Message, WireError
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Callable
_REP_MARKER = const(63) # ord('?')
_REP_MAGIC = const(35) # org('#')
@ -19,7 +20,7 @@ class CodecError(WireError):
pass
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
async def read_message(iface: WireInterface, buffer_getter: Callable[[], bytearray]) -> Message:
read = loop.wait(iface.iface_num() | io.POLL_READ)
report = bytearray(iface.RX_PACKET_LEN)
@ -33,6 +34,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
raise CodecError("Invalid magic")
buffer = buffer_getter() # will throw if other session is in progress
read_and_throw_away = False
if msize > len(buffer):