mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-04-08 03:16:13 +00:00
fixup! feat(core): add BleInterface
to session handling
This commit is contained in:
parent
92bcbd97bc
commit
bf0b939c6f
@ -357,8 +357,6 @@ if __debug__:
|
||||
async def _no_op(_msg: Any) -> Success:
|
||||
return Success()
|
||||
|
||||
WIRE_BUFFER_DEBUG = bytearray(1024)
|
||||
|
||||
async def handle_session(iface: WireInterface) -> None:
|
||||
from trezor import protobuf, wire
|
||||
from trezor.wire.codec import codec_v1
|
||||
@ -366,7 +364,7 @@ if __debug__:
|
||||
|
||||
global DEBUG_CONTEXT
|
||||
|
||||
DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG)
|
||||
DEBUG_CONTEXT = ctx = CodecContext(iface, wire.BufferProvider(1024), "Debug")
|
||||
|
||||
if storage.layout_watcher:
|
||||
try:
|
||||
|
@ -6,12 +6,7 @@ from trezor import log, loop, utils, wire, workflow
|
||||
import apps.base
|
||||
import usb
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
USB_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
if utils.USE_BLE:
|
||||
BLE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
buffer_provider = wire.BufferProvider(8192)
|
||||
|
||||
apps.base.boot()
|
||||
|
||||
@ -30,13 +25,13 @@ apps.base.set_homescreen()
|
||||
workflow.start_default()
|
||||
|
||||
# initialize the wire codec over USB
|
||||
wire.setup(usb.iface_wire, USB_BUFFER)
|
||||
wire.setup(usb.iface_wire, buffer_provider, "USB")
|
||||
|
||||
if utils.USE_BLE:
|
||||
import bluetooth
|
||||
|
||||
# initialize the wire codec over BLE
|
||||
wire.setup(bluetooth.iface_ble, BLE_BUFFER)
|
||||
wire.setup(bluetooth.iface_ble, buffer_provider, "BLE")
|
||||
|
||||
# start the event loop
|
||||
loop.run()
|
||||
|
@ -47,13 +47,28 @@ if TYPE_CHECKING:
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
def setup(iface: WireInterface, buffer: bytearray) -> None:
|
||||
class BufferProvider:
|
||||
def __init__(self, size):
|
||||
self.buf = bytearray(size)
|
||||
self.owner = None
|
||||
|
||||
def take(self, owner: str) -> bytearray | None:
|
||||
if self.buf is None:
|
||||
return None
|
||||
|
||||
buf = self.buf
|
||||
self.buf = None
|
||||
self.owner = owner
|
||||
return buf
|
||||
|
||||
|
||||
def setup(iface: WireInterface, buffer_provider: BufferProvider, name: str) -> None:
|
||||
"""Initialize the wire stack on the provided WireInterface."""
|
||||
loop.schedule(handle_session(iface, buffer))
|
||||
loop.schedule(handle_session(iface, buffer_provider, name))
|
||||
|
||||
|
||||
async def handle_session(iface: WireInterface, buffer: bytearray) -> None:
|
||||
ctx = CodecContext(iface, buffer)
|
||||
async def handle_session(iface: WireInterface, buffer_provider: BufferProvider, name: str) -> None:
|
||||
ctx = CodecContext(iface, buffer_provider, name)
|
||||
next_msg: protocol_common.Message | None = None
|
||||
|
||||
# Take a mark of modules that are imported at this point, so we can
|
||||
|
@ -5,12 +5,12 @@ from storage.cache_common import DataCache, InvalidSessionError
|
||||
from trezor import log, protobuf
|
||||
from trezor.wire.codec import codec_v1
|
||||
from trezor.wire.context import UnexpectedMessageException
|
||||
from trezor.wire.protocol_common import Context, Message
|
||||
from trezor.wire.protocol_common import Context, Message, WireError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeVar
|
||||
|
||||
from trezor.wire import WireInterface
|
||||
from trezor.wire import WireInterface, BufferProvider
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
@ -21,14 +21,29 @@ class CodecContext(Context):
|
||||
def __init__(
|
||||
self,
|
||||
iface: WireInterface,
|
||||
buffer: bytearray,
|
||||
buffer_provider: BufferProvider,
|
||||
name: str,
|
||||
) -> None:
|
||||
self.buffer = buffer
|
||||
self.buffer_provider = buffer_provider
|
||||
self._buffer = None
|
||||
self.name = name
|
||||
super().__init__(iface)
|
||||
|
||||
def get_buffer(self) -> bytearray:
|
||||
if self._buffer is not None:
|
||||
return self._buffer
|
||||
|
||||
self._buffer = self.buffer_provider.take(self.name)
|
||||
if self._buffer is not None:
|
||||
return self._buffer
|
||||
|
||||
# The exception should be caught by and handled by `wire.handle_session()` task.
|
||||
# It doesn't terminate the "blocked" session (to allow sending error responses).
|
||||
raise WireError(f"{self.buffer_provider.owner} session in progress, {self.name} is blocked")
|
||||
|
||||
def read_from_wire(self) -> Awaitable[Message]:
|
||||
"""Read a whole message from the wire without parsing it."""
|
||||
return codec_v1.read_message(self.iface, self.buffer)
|
||||
return codec_v1.read_message(self.iface, self.get_buffer)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
@ -81,10 +96,15 @@ class CodecContext(Context):
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
|
||||
if msg_size <= len(self.buffer):
|
||||
# reuse preallocated
|
||||
buffer = self.buffer
|
||||
if self._buffer is not None:
|
||||
buffer = self._buffer
|
||||
else:
|
||||
# Allow sending small responses (for error reporting when another session is in progress
|
||||
if msg_size > 128:
|
||||
raise MemoryError(msg_size) ### FIXME
|
||||
buffer = bytearray()
|
||||
|
||||
if msg_size > len(buffer):
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer = bytearray(msg_size)
|
||||
|
||||
|
@ -7,6 +7,7 @@ from trezor.wire.protocol_common import Message, WireError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Callable
|
||||
|
||||
_REP_MARKER = const(63) # ord('?')
|
||||
_REP_MAGIC = const(35) # org('#')
|
||||
@ -19,7 +20,7 @@ class CodecError(WireError):
|
||||
pass
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
||||
async def read_message(iface: WireInterface, buffer_getter: Callable[[], bytearray]) -> Message:
|
||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||
report = bytearray(iface.RX_PACKET_LEN)
|
||||
|
||||
@ -33,6 +34,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
||||
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
||||
raise CodecError("Invalid magic")
|
||||
|
||||
buffer = buffer_getter() # will throw if other session is in progress
|
||||
read_and_throw_away = False
|
||||
|
||||
if msize > len(buffer):
|
||||
|
Loading…
Reference in New Issue
Block a user