mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-19 20:01:11 +00:00
feat(core): keep track of current context for workflow
This commit is contained in:
parent
78a8b48f1e
commit
fe80793b47
@ -193,6 +193,8 @@ trezor.wire
|
|||||||
import trezor.wire
|
import trezor.wire
|
||||||
trezor.wire.codec_v1
|
trezor.wire.codec_v1
|
||||||
import trezor.wire.codec_v1
|
import trezor.wire.codec_v1
|
||||||
|
trezor.wire.context
|
||||||
|
import trezor.wire.context
|
||||||
trezor.wire.errors
|
trezor.wire.errors
|
||||||
import trezor.wire.errors
|
import trezor.wire.errors
|
||||||
trezor.workflow
|
trezor.workflow
|
||||||
|
@ -42,7 +42,7 @@ from storage.cache import InvalidSessionError
|
|||||||
from trezor import log, loop, protobuf, utils, workflow
|
from trezor import log, loop, protobuf, utils, workflow
|
||||||
from trezor.enums import FailureType
|
from trezor.enums import FailureType
|
||||||
from trezor.messages import Failure
|
from trezor.messages import Failure
|
||||||
from trezor.wire import codec_v1
|
from trezor.wire import codec_v1, context
|
||||||
from trezor.wire.errors import ActionCancelled, DataError, Error
|
from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||||
|
|
||||||
# Import all errors into namespace, so that `wire.Error` is available from
|
# Import all errors into namespace, so that `wire.Error` is available from
|
||||||
@ -53,43 +53,22 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
|
||||||
Callable,
|
Callable,
|
||||||
Container,
|
Container,
|
||||||
Coroutine,
|
Coroutine,
|
||||||
Iterable,
|
|
||||||
Protocol,
|
|
||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
|
|
||||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||||
Handler = Callable[["Context", Msg], HandlerTask]
|
Handler = Callable[[Msg], HandlerTask]
|
||||||
|
|
||||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
class GenericContext(Protocol):
|
|
||||||
async def call(
|
|
||||||
self,
|
|
||||||
msg: protobuf.MessageType,
|
|
||||||
expected_type: type[protobuf.MessageType],
|
|
||||||
) -> Any:
|
|
||||||
...
|
|
||||||
|
|
||||||
async def read(self, expected_type: type[protobuf.MessageType]) -> Any:
|
|
||||||
...
|
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
# XXX modify type signature so that the return value must be of the same type?
|
|
||||||
async def wait(self, *tasks: Awaitable) -> Any:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||||
experimental_enabled = False
|
EXPERIMENTAL_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
||||||
@ -97,12 +76,12 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
|||||||
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
|
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
|
||||||
|
|
||||||
|
|
||||||
def _wrap_protobuf_load(
|
def wrap_protobuf_load(
|
||||||
buffer: bytes,
|
buffer: bytes,
|
||||||
expected_type: type[LoadedMessageType],
|
expected_type: type[LoadedMessageType],
|
||||||
) -> LoadedMessageType:
|
) -> LoadedMessageType:
|
||||||
try:
|
try:
|
||||||
msg = protobuf.decode(buffer, expected_type, experimental_enabled)
|
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||||
if __debug__ and utils.EMULATOR:
|
if __debug__ and utils.EMULATOR:
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||||
@ -117,22 +96,6 @@ def _wrap_protobuf_load(
|
|||||||
raise DataError("Failed to decode message")
|
raise DataError("Failed to decode message")
|
||||||
|
|
||||||
|
|
||||||
class DummyContext:
|
|
||||||
async def call(self, *argv: Any) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def read(self, *argv: Any) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def write(self, *argv: Any) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def wait(self, *tasks: Awaitable) -> Any:
|
|
||||||
return await loop.race(*tasks)
|
|
||||||
|
|
||||||
|
|
||||||
DUMMY_CONTEXT = DummyContext()
|
|
||||||
|
|
||||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||||
|
|
||||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||||
@ -142,143 +105,8 @@ if __debug__:
|
|||||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
|
||||||
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
|
|
||||||
self.iface = iface
|
|
||||||
self.sid = sid
|
|
||||||
self.buffer = buffer
|
|
||||||
|
|
||||||
async def call(
|
|
||||||
self,
|
|
||||||
msg: protobuf.MessageType,
|
|
||||||
expected_type: type[LoadedMessageType],
|
|
||||||
) -> LoadedMessageType:
|
|
||||||
await self.write(msg)
|
|
||||||
del msg
|
|
||||||
return await self.read(expected_type)
|
|
||||||
|
|
||||||
async def call_any(
|
|
||||||
self, msg: protobuf.MessageType, *expected_wire_types: int
|
|
||||||
) -> protobuf.MessageType:
|
|
||||||
await self.write(msg)
|
|
||||||
del msg
|
|
||||||
return await self.read_any(expected_wire_types)
|
|
||||||
|
|
||||||
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
|
||||||
return codec_v1.read_message(self.iface, self.buffer)
|
|
||||||
|
|
||||||
async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType:
|
|
||||||
if __debug__:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%s:%x expect: %s",
|
|
||||||
self.iface.iface_num(),
|
|
||||||
self.sid,
|
|
||||||
expected_type.MESSAGE_NAME,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load the full message into a buffer, parse out type and data payload
|
|
||||||
msg = await self.read_from_wire()
|
|
||||||
|
|
||||||
# If we got a message with unexpected type, raise the message via
|
|
||||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
|
||||||
if msg.type != expected_type.MESSAGE_WIRE_TYPE:
|
|
||||||
raise UnexpectedMessageError(msg)
|
|
||||||
|
|
||||||
if __debug__:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%s:%x read: %s",
|
|
||||||
self.iface.iface_num(),
|
|
||||||
self.sid,
|
|
||||||
expected_type.MESSAGE_NAME,
|
|
||||||
)
|
|
||||||
|
|
||||||
# look up the protobuf class and parse the message
|
|
||||||
return _wrap_protobuf_load(msg.data, expected_type)
|
|
||||||
|
|
||||||
async def read_any(
|
|
||||||
self, expected_wire_types: Iterable[int]
|
|
||||||
) -> protobuf.MessageType:
|
|
||||||
if __debug__:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%s:%x expect: %s",
|
|
||||||
self.iface.iface_num(),
|
|
||||||
self.sid,
|
|
||||||
expected_wire_types,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load the full message into a buffer, parse out type and data payload
|
|
||||||
msg = await self.read_from_wire()
|
|
||||||
|
|
||||||
# If we got a message with unexpected type, raise the message via
|
|
||||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
|
||||||
if msg.type not in expected_wire_types:
|
|
||||||
raise UnexpectedMessageError(msg)
|
|
||||||
|
|
||||||
# find the protobuf type
|
|
||||||
exptype = protobuf.type_for_wire(msg.type)
|
|
||||||
|
|
||||||
if __debug__:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%s:%x read: %s",
|
|
||||||
self.iface.iface_num(),
|
|
||||||
self.sid,
|
|
||||||
exptype.MESSAGE_NAME,
|
|
||||||
)
|
|
||||||
|
|
||||||
# parse the message and return it
|
|
||||||
return _wrap_protobuf_load(msg.data, exptype)
|
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
|
||||||
if __debug__:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"%s:%x write: %s",
|
|
||||||
self.iface.iface_num(),
|
|
||||||
self.sid,
|
|
||||||
msg.MESSAGE_NAME,
|
|
||||||
)
|
|
||||||
|
|
||||||
# cannot write message without wire type
|
|
||||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
|
||||||
|
|
||||||
msg_size = protobuf.encoded_length(msg)
|
|
||||||
|
|
||||||
if msg_size <= len(self.buffer):
|
|
||||||
# reuse preallocated
|
|
||||||
buffer = self.buffer
|
|
||||||
else:
|
|
||||||
# message is too big, we need to allocate a new buffer
|
|
||||||
buffer = bytearray(msg_size)
|
|
||||||
|
|
||||||
msg_size = protobuf.encode(buffer, msg)
|
|
||||||
|
|
||||||
await codec_v1.write_message(
|
|
||||||
self.iface,
|
|
||||||
msg.MESSAGE_WIRE_TYPE,
|
|
||||||
memoryview(buffer)[:msg_size],
|
|
||||||
)
|
|
||||||
|
|
||||||
def wait(self, *tasks: Awaitable) -> Any:
|
|
||||||
"""
|
|
||||||
Wait until one of the passed tasks finishes, and return the result,
|
|
||||||
while servicing the wire context. If a message comes until one of the
|
|
||||||
tasks ends, `UnexpectedMessageError` is raised.
|
|
||||||
"""
|
|
||||||
return loop.race(self.read_any(()), *tasks)
|
|
||||||
|
|
||||||
|
|
||||||
class UnexpectedMessageError(Exception):
|
|
||||||
def __init__(self, msg: codec_v1.Message) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.msg = msg
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_single_message(
|
async def _handle_single_message(
|
||||||
ctx: Context, msg: codec_v1.Message, use_workflow: bool
|
ctx: context.Context, msg: codec_v1.Message, use_workflow: bool
|
||||||
) -> codec_v1.Message | None:
|
) -> codec_v1.Message | None:
|
||||||
"""Handle a message that was loaded from USB by the caller.
|
"""Handle a message that was loaded from USB by the caller.
|
||||||
|
|
||||||
@ -288,10 +116,10 @@ async def _handle_single_message(
|
|||||||
If the workflow finished normally or with an error, the return value is None.
|
If the workflow finished normally or with an error, the return value is None.
|
||||||
|
|
||||||
If an unexpected message had arrived on the wire while the workflow was processing,
|
If an unexpected message had arrived on the wire while the workflow was processing,
|
||||||
the workflow is shut down with an `UnexpectedMessageError`. This is not considered
|
the workflow is shut down with an `UnexpectedMessage` exception. This is not
|
||||||
an "error condition" to return over the wire -- instead the message is processed
|
considered an "error condition" to return over the wire -- instead the message
|
||||||
as if starting a new workflow.
|
is processed as if starting a new workflow.
|
||||||
In such case, the `UnexpectedMessageError` is caught and the message is returned
|
In such case, the `UnexpectedMessage` is caught and the message is returned
|
||||||
to the caller. It will then be processed in the next iteration of the message loop.
|
to the caller. It will then be processed in the next iteration of the message loop.
|
||||||
"""
|
"""
|
||||||
if __debug__:
|
if __debug__:
|
||||||
@ -330,10 +158,10 @@ async def _handle_single_message(
|
|||||||
|
|
||||||
# Try to decode the message according to schema from
|
# Try to decode the message according to schema from
|
||||||
# `req_type`. Raises if the message is malformed.
|
# `req_type`. Raises if the message is malformed.
|
||||||
req_msg = _wrap_protobuf_load(msg.data, req_type)
|
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||||
|
|
||||||
# Create the handler task.
|
# Create the handler task.
|
||||||
task = handler(ctx, req_msg)
|
task = handler(req_msg)
|
||||||
|
|
||||||
# Run the workflow task. Workflow can do more on-the-wire
|
# Run the workflow task. Workflow can do more on-the-wire
|
||||||
# communication inside, but it should eventually return a
|
# communication inside, but it should eventually return a
|
||||||
@ -342,17 +170,17 @@ async def _handle_single_message(
|
|||||||
if use_workflow:
|
if use_workflow:
|
||||||
# Spawn a workflow around the task. This ensures that concurrent
|
# Spawn a workflow around the task. This ensures that concurrent
|
||||||
# workflows are shut down.
|
# workflows are shut down.
|
||||||
res_msg = await workflow.spawn(task)
|
res_msg = await workflow.spawn(context.with_context(ctx, task))
|
||||||
else:
|
else:
|
||||||
# For debug messages, ignore workflow processing and just await
|
# For debug messages, ignore workflow processing and just await
|
||||||
# results of the handler.
|
# results of the handler.
|
||||||
res_msg = await task
|
res_msg = await task
|
||||||
|
|
||||||
except UnexpectedMessageError as exc:
|
except context.UnexpectedMessage as exc:
|
||||||
# Workflow was trying to read a message from the wire, and
|
# Workflow was trying to read a message from the wire, and
|
||||||
# something unexpected came in. See Context.read() for
|
# something unexpected came in. See Context.read() for
|
||||||
# example, which expects some particular message and raises
|
# example, which expects some particular message and raises
|
||||||
# UnexpectedMessageError if another one comes in.
|
# UnexpectedMessage if another one comes in.
|
||||||
# In order not to lose the message, we return it to the caller.
|
# In order not to lose the message, we return it to the caller.
|
||||||
# TODO:
|
# TODO:
|
||||||
# We might handle only the few common cases here, like
|
# We might handle only the few common cases here, like
|
||||||
@ -390,7 +218,7 @@ async def handle_session(
|
|||||||
else:
|
else:
|
||||||
ctx_buffer = WIRE_BUFFER
|
ctx_buffer = WIRE_BUFFER
|
||||||
|
|
||||||
ctx = Context(iface, session_id, ctx_buffer)
|
ctx = context.Context(iface, session_id, ctx_buffer)
|
||||||
next_msg: codec_v1.Message | None = None
|
next_msg: codec_v1.Message | None = None
|
||||||
|
|
||||||
if __debug__ and is_debug_session:
|
if __debug__ and is_debug_session:
|
||||||
|
266
core/src/trezor/wire/context.py
Normal file
266
core/src/trezor/wire/context.py
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
"""Context pseudo-global.
|
||||||
|
|
||||||
|
Each workflow handler runs in a "context" which is tied to a particular communication
|
||||||
|
session. When the handler needs to communicate with the host, it needs access to that
|
||||||
|
context.
|
||||||
|
|
||||||
|
To avoid the need to pass a context object around, the context is stored in a
|
||||||
|
pseudo-global manner: any workflow handler can request access to the context via this
|
||||||
|
module, and the appropriate context object will be used for it.
|
||||||
|
|
||||||
|
Some workflows don't need a context to exist. This is supported by the `maybe_call`
|
||||||
|
function, which will silently ignore the call if no context is available. Useful mainly
|
||||||
|
for ButtonRequests. Of course, `context.wait()` transparently works in such situations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from trezor import log, loop, protobuf
|
||||||
|
|
||||||
|
from . import codec_v1
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
Container,
|
||||||
|
Coroutine,
|
||||||
|
Generator,
|
||||||
|
TypeVar,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
from trezorio import WireInterface
|
||||||
|
|
||||||
|
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||||
|
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||||
|
Handler = Callable[["Context", Msg], HandlerTask]
|
||||||
|
|
||||||
|
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||||
|
|
||||||
|
|
||||||
|
class UnexpectedMessage(Exception):
|
||||||
|
"""A message was received that is not part of the current workflow.
|
||||||
|
|
||||||
|
Utility exception to inform the session handler that the current workflow
|
||||||
|
should be aborted and a new one started as if `msg` was the first message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, msg: codec_v1.Message) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
|
class Context:
|
||||||
|
"""Wire context.
|
||||||
|
|
||||||
|
Represents USB communication inside a particular session on a particular interface
|
||||||
|
(i.e., wire, debug, single BT connection, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
|
||||||
|
self.iface = iface
|
||||||
|
self.sid = sid
|
||||||
|
self.buffer = buffer
|
||||||
|
|
||||||
|
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
||||||
|
"""Read a whole message from the wire without parsing it."""
|
||||||
|
return codec_v1.read_message(self.iface, self.buffer)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def read(self, expected_types: Container[int]) -> protobuf.MessageType:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def read(
|
||||||
|
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||||
|
) -> LoadedMessageType:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def read(
|
||||||
|
self,
|
||||||
|
expected_types: Container[int],
|
||||||
|
expected_type: type[protobuf.MessageType] | None = None,
|
||||||
|
) -> protobuf.MessageType:
|
||||||
|
"""Read a message from the wire.
|
||||||
|
|
||||||
|
The read message must be of one of the types specified in `expected_types`.
|
||||||
|
If only a single type is expected, it can be passed as `expected_type`,
|
||||||
|
to save on having to decode the type code into a protobuf class.
|
||||||
|
"""
|
||||||
|
if __debug__:
|
||||||
|
log.debug(
|
||||||
|
__name__,
|
||||||
|
"%s:%x expect: %s",
|
||||||
|
self.iface.iface_num(),
|
||||||
|
self.sid,
|
||||||
|
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load the full message into a buffer, parse out type and data payload
|
||||||
|
msg = await self.read_from_wire()
|
||||||
|
|
||||||
|
# If we got a message with unexpected type, raise the message via
|
||||||
|
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||||
|
if msg.type not in expected_types:
|
||||||
|
raise UnexpectedMessage(msg)
|
||||||
|
|
||||||
|
if expected_type is None:
|
||||||
|
expected_type = protobuf.type_for_wire(msg.type)
|
||||||
|
|
||||||
|
if __debug__:
|
||||||
|
log.debug(
|
||||||
|
__name__,
|
||||||
|
"%s:%x read: %s",
|
||||||
|
self.iface.iface_num(),
|
||||||
|
self.sid,
|
||||||
|
expected_type.MESSAGE_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
# look up the protobuf class and parse the message
|
||||||
|
from . import wrap_protobuf_load
|
||||||
|
|
||||||
|
return wrap_protobuf_load(msg.data, expected_type)
|
||||||
|
|
||||||
|
async def write(self, msg: protobuf.MessageType) -> None:
|
||||||
|
"""Write a message to the wire."""
|
||||||
|
if __debug__:
|
||||||
|
log.debug(
|
||||||
|
__name__,
|
||||||
|
"%s:%x write: %s",
|
||||||
|
self.iface.iface_num(),
|
||||||
|
self.sid,
|
||||||
|
msg.MESSAGE_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
# cannot write message without wire type
|
||||||
|
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||||
|
|
||||||
|
msg_size = protobuf.encoded_length(msg)
|
||||||
|
|
||||||
|
if msg_size <= len(self.buffer):
|
||||||
|
# reuse preallocated
|
||||||
|
buffer = self.buffer
|
||||||
|
else:
|
||||||
|
# message is too big, we need to allocate a new buffer
|
||||||
|
buffer = bytearray(msg_size)
|
||||||
|
|
||||||
|
msg_size = protobuf.encode(buffer, msg)
|
||||||
|
|
||||||
|
await codec_v1.write_message(
|
||||||
|
self.iface,
|
||||||
|
msg.MESSAGE_WIRE_TYPE,
|
||||||
|
memoryview(buffer)[:msg_size],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
CURRENT_CONTEXT: Context | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def wait(*tasks: Awaitable) -> Any:
|
||||||
|
"""
|
||||||
|
Wait until one of the passed tasks finishes, and return the result, while servicing
|
||||||
|
the wire context.
|
||||||
|
|
||||||
|
Used to make sure the device is responsive on USB while waiting for user
|
||||||
|
interaction. If a message is received before any of the passed in tasks finish, it
|
||||||
|
raises an `UnexpectedMessage` exception, returning control to the session handler.
|
||||||
|
"""
|
||||||
|
if CURRENT_CONTEXT is None:
|
||||||
|
return loop.race(*tasks)
|
||||||
|
else:
|
||||||
|
return loop.race(CURRENT_CONTEXT.read(()), *tasks)
|
||||||
|
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
msg: protobuf.MessageType,
|
||||||
|
expected_type: type[LoadedMessageType],
|
||||||
|
) -> LoadedMessageType:
|
||||||
|
"""Send a message to the host and wait for a response of a particular type.
|
||||||
|
|
||||||
|
Raises if there is no context for this workflow."""
|
||||||
|
if CURRENT_CONTEXT is None:
|
||||||
|
raise RuntimeError("No wire context")
|
||||||
|
|
||||||
|
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
||||||
|
|
||||||
|
await CURRENT_CONTEXT.write(msg)
|
||||||
|
del msg
|
||||||
|
return await CURRENT_CONTEXT.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def call_any(
|
||||||
|
msg: protobuf.MessageType, *expected_wire_types: int
|
||||||
|
) -> protobuf.MessageType:
|
||||||
|
"""Send a message to the host and wait for a response.
|
||||||
|
|
||||||
|
The response can be of any of the types specified in `expected_wire_types`.
|
||||||
|
|
||||||
|
Raises if there is no context for this workflow."""
|
||||||
|
if CURRENT_CONTEXT is None:
|
||||||
|
raise RuntimeError("No wire context")
|
||||||
|
|
||||||
|
await CURRENT_CONTEXT.write(msg)
|
||||||
|
del msg
|
||||||
|
return await CURRENT_CONTEXT.read(expected_wire_types)
|
||||||
|
|
||||||
|
|
||||||
|
async def maybe_call(
|
||||||
|
msg: protobuf.MessageType, expected_type: type[LoadedMessageType]
|
||||||
|
) -> None:
|
||||||
|
"""Send a message to the host and read but ignore the response.
|
||||||
|
|
||||||
|
If there is a context, the function still checks that the response is of the
|
||||||
|
requested type. If there is no context, the call is ignored.
|
||||||
|
"""
|
||||||
|
if CURRENT_CONTEXT is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
await call(msg, expected_type)
|
||||||
|
|
||||||
|
|
||||||
|
def get_context() -> Context:
|
||||||
|
"""Get the current session context.
|
||||||
|
|
||||||
|
Can be needed in case the caller needs raw read and raw write capabilities, which
|
||||||
|
are not provided by the module functions.
|
||||||
|
|
||||||
|
Result of this function should not be stored -- the context is technically allowed
|
||||||
|
to change inbetween any `await` statements.
|
||||||
|
"""
|
||||||
|
if CURRENT_CONTEXT is None:
|
||||||
|
raise RuntimeError("No wire context")
|
||||||
|
return CURRENT_CONTEXT
|
||||||
|
|
||||||
|
|
||||||
|
def with_context(ctx: Context, workflow: loop.Task) -> Generator:
|
||||||
|
"""Run a workflow in a particular context.
|
||||||
|
|
||||||
|
Stores the context in a closure and installs it into the global variable every time
|
||||||
|
the closure is resumed, thus making sure that all calls to `wire.context.*` will
|
||||||
|
work as expected.
|
||||||
|
"""
|
||||||
|
global CURRENT_CONTEXT
|
||||||
|
send_val = None
|
||||||
|
send_exc = None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
CURRENT_CONTEXT = ctx
|
||||||
|
try:
|
||||||
|
if send_exc is not None:
|
||||||
|
res = workflow.throw(send_exc)
|
||||||
|
else:
|
||||||
|
res = workflow.send(send_val)
|
||||||
|
except StopIteration as st:
|
||||||
|
return st.value
|
||||||
|
finally:
|
||||||
|
CURRENT_CONTEXT = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
send_val = yield res
|
||||||
|
except BaseException as e:
|
||||||
|
send_exc = e
|
||||||
|
else:
|
||||||
|
send_exc = None
|
Loading…
Reference in New Issue
Block a user