mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-17 20:18:10 +00:00
fixup! feat(core): add BleInterface
to session handling
This commit is contained in:
parent
92bcbd97bc
commit
bf0b939c6f
@ -357,8 +357,6 @@ if __debug__:
|
|||||||
async def _no_op(_msg: Any) -> Success:
|
async def _no_op(_msg: Any) -> Success:
|
||||||
return Success()
|
return Success()
|
||||||
|
|
||||||
WIRE_BUFFER_DEBUG = bytearray(1024)
|
|
||||||
|
|
||||||
async def handle_session(iface: WireInterface) -> None:
|
async def handle_session(iface: WireInterface) -> None:
|
||||||
from trezor import protobuf, wire
|
from trezor import protobuf, wire
|
||||||
from trezor.wire.codec import codec_v1
|
from trezor.wire.codec import codec_v1
|
||||||
@ -366,7 +364,7 @@ if __debug__:
|
|||||||
|
|
||||||
global DEBUG_CONTEXT
|
global DEBUG_CONTEXT
|
||||||
|
|
||||||
DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG)
|
DEBUG_CONTEXT = ctx = CodecContext(iface, wire.BufferProvider(1024), "Debug")
|
||||||
|
|
||||||
if storage.layout_watcher:
|
if storage.layout_watcher:
|
||||||
try:
|
try:
|
||||||
|
@ -6,12 +6,7 @@ from trezor import log, loop, utils, wire, workflow
|
|||||||
import apps.base
|
import apps.base
|
||||||
import usb
|
import usb
|
||||||
|
|
||||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
buffer_provider = wire.BufferProvider(8192)
|
||||||
USB_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
|
||||||
|
|
||||||
if utils.USE_BLE:
|
|
||||||
BLE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
|
||||||
|
|
||||||
|
|
||||||
apps.base.boot()
|
apps.base.boot()
|
||||||
|
|
||||||
@ -30,13 +25,13 @@ apps.base.set_homescreen()
|
|||||||
workflow.start_default()
|
workflow.start_default()
|
||||||
|
|
||||||
# initialize the wire codec over USB
|
# initialize the wire codec over USB
|
||||||
wire.setup(usb.iface_wire, USB_BUFFER)
|
wire.setup(usb.iface_wire, buffer_provider, "USB")
|
||||||
|
|
||||||
if utils.USE_BLE:
|
if utils.USE_BLE:
|
||||||
import bluetooth
|
import bluetooth
|
||||||
|
|
||||||
# initialize the wire codec over BLE
|
# initialize the wire codec over BLE
|
||||||
wire.setup(bluetooth.iface_ble, BLE_BUFFER)
|
wire.setup(bluetooth.iface_ble, buffer_provider, "BLE")
|
||||||
|
|
||||||
# start the event loop
|
# start the event loop
|
||||||
loop.run()
|
loop.run()
|
||||||
|
@ -47,13 +47,28 @@ if TYPE_CHECKING:
|
|||||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
|
|
||||||
def setup(iface: WireInterface, buffer: bytearray) -> None:
|
class BufferProvider:
|
||||||
|
def __init__(self, size):
|
||||||
|
self.buf = bytearray(size)
|
||||||
|
self.owner = None
|
||||||
|
|
||||||
|
def take(self, owner: str) -> bytearray | None:
|
||||||
|
if self.buf is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
buf = self.buf
|
||||||
|
self.buf = None
|
||||||
|
self.owner = owner
|
||||||
|
return buf
|
||||||
|
|
||||||
|
|
||||||
|
def setup(iface: WireInterface, buffer_provider: BufferProvider, name: str) -> None:
|
||||||
"""Initialize the wire stack on the provided WireInterface."""
|
"""Initialize the wire stack on the provided WireInterface."""
|
||||||
loop.schedule(handle_session(iface, buffer))
|
loop.schedule(handle_session(iface, buffer_provider, name))
|
||||||
|
|
||||||
|
|
||||||
async def handle_session(iface: WireInterface, buffer: bytearray) -> None:
|
async def handle_session(iface: WireInterface, buffer_provider: BufferProvider, name: str) -> None:
|
||||||
ctx = CodecContext(iface, buffer)
|
ctx = CodecContext(iface, buffer_provider, name)
|
||||||
next_msg: protocol_common.Message | None = None
|
next_msg: protocol_common.Message | None = None
|
||||||
|
|
||||||
# Take a mark of modules that are imported at this point, so we can
|
# Take a mark of modules that are imported at this point, so we can
|
||||||
|
@ -5,12 +5,12 @@ from storage.cache_common import DataCache, InvalidSessionError
|
|||||||
from trezor import log, protobuf
|
from trezor import log, protobuf
|
||||||
from trezor.wire.codec import codec_v1
|
from trezor.wire.codec import codec_v1
|
||||||
from trezor.wire.context import UnexpectedMessageException
|
from trezor.wire.context import UnexpectedMessageException
|
||||||
from trezor.wire.protocol_common import Context, Message
|
from trezor.wire.protocol_common import Context, Message, WireError
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from trezor.wire import WireInterface
|
from trezor.wire import WireInterface, BufferProvider
|
||||||
|
|
||||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
@ -21,14 +21,29 @@ class CodecContext(Context):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
iface: WireInterface,
|
iface: WireInterface,
|
||||||
buffer: bytearray,
|
buffer_provider: BufferProvider,
|
||||||
|
name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.buffer = buffer
|
self.buffer_provider = buffer_provider
|
||||||
|
self._buffer = None
|
||||||
|
self.name = name
|
||||||
super().__init__(iface)
|
super().__init__(iface)
|
||||||
|
|
||||||
|
def get_buffer(self) -> bytearray:
|
||||||
|
if self._buffer is not None:
|
||||||
|
return self._buffer
|
||||||
|
|
||||||
|
self._buffer = self.buffer_provider.take(self.name)
|
||||||
|
if self._buffer is not None:
|
||||||
|
return self._buffer
|
||||||
|
|
||||||
|
# The exception should be caught by and handled by `wire.handle_session()` task.
|
||||||
|
# It doesn't terminate the "blocked" session (to allow sending error responses).
|
||||||
|
raise WireError(f"{self.buffer_provider.owner} session in progress, {self.name} is blocked")
|
||||||
|
|
||||||
def read_from_wire(self) -> Awaitable[Message]:
|
def read_from_wire(self) -> Awaitable[Message]:
|
||||||
"""Read a whole message from the wire without parsing it."""
|
"""Read a whole message from the wire without parsing it."""
|
||||||
return codec_v1.read_message(self.iface, self.buffer)
|
return codec_v1.read_message(self.iface, self.get_buffer)
|
||||||
|
|
||||||
async def read(
|
async def read(
|
||||||
self,
|
self,
|
||||||
@ -81,10 +96,15 @@ class CodecContext(Context):
|
|||||||
|
|
||||||
msg_size = protobuf.encoded_length(msg)
|
msg_size = protobuf.encoded_length(msg)
|
||||||
|
|
||||||
if msg_size <= len(self.buffer):
|
if self._buffer is not None:
|
||||||
# reuse preallocated
|
buffer = self._buffer
|
||||||
buffer = self.buffer
|
|
||||||
else:
|
else:
|
||||||
|
# Allow sending small responses (for error reporting when another session is in progress
|
||||||
|
if msg_size > 128:
|
||||||
|
raise MemoryError(msg_size) ### FIXME
|
||||||
|
buffer = bytearray()
|
||||||
|
|
||||||
|
if msg_size > len(buffer):
|
||||||
# message is too big, we need to allocate a new buffer
|
# message is too big, we need to allocate a new buffer
|
||||||
buffer = bytearray(msg_size)
|
buffer = bytearray(msg_size)
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ from trezor.wire.protocol_common import Message, WireError
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
_REP_MARKER = const(63) # ord('?')
|
_REP_MARKER = const(63) # ord('?')
|
||||||
_REP_MAGIC = const(35) # org('#')
|
_REP_MAGIC = const(35) # org('#')
|
||||||
@ -19,7 +20,7 @@ class CodecError(WireError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
async def read_message(iface: WireInterface, buffer_getter: Callable[[], bytearray]) -> Message:
|
||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
report = bytearray(iface.RX_PACKET_LEN)
|
report = bytearray(iface.RX_PACKET_LEN)
|
||||||
|
|
||||||
@ -33,6 +34,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|||||||
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
if magic1 != _REP_MAGIC or magic2 != _REP_MAGIC:
|
||||||
raise CodecError("Invalid magic")
|
raise CodecError("Invalid magic")
|
||||||
|
|
||||||
|
buffer = buffer_getter() # will throw if other session is in progress
|
||||||
read_and_throw_away = False
|
read_and_throw_away = False
|
||||||
|
|
||||||
if msize > len(buffer):
|
if msize > len(buffer):
|
||||||
|
Loading…
Reference in New Issue
Block a user