mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-08 05:32:39 +00:00
refactor(core): replace loop.chan with loop.mailbox
[no changelog]
This commit is contained in:
parent
22be48cca4
commit
c241adfc4d
@ -76,7 +76,7 @@ class PairingContext(Context):
|
|||||||
def __init__(self, channel_ctx: Channel) -> None:
|
def __init__(self, channel_ctx: Channel) -> None:
|
||||||
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
|
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
|
||||||
self.channel_ctx: Channel = channel_ctx
|
self.channel_ctx: Channel = channel_ctx
|
||||||
self.incoming_message = loop.chan()
|
self.incoming_message = loop.mailbox()
|
||||||
self.secret: bytes = random.bytes(16)
|
self.secret: bytes = random.bytes(16)
|
||||||
|
|
||||||
self.display_data: PairingDisplayData = PairingDisplayData()
|
self.display_data: PairingDisplayData = PairingDisplayData()
|
||||||
@ -91,7 +91,6 @@ class PairingContext(Context):
|
|||||||
|
|
||||||
# apps.debug.DEBUG_CONTEXT = self
|
# apps.debug.DEBUG_CONTEXT = self
|
||||||
|
|
||||||
take = self.incoming_message.take()
|
|
||||||
next_message: Message | None = None
|
next_message: Message | None = None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@ -100,7 +99,7 @@ class PairingContext(Context):
|
|||||||
# If the previous run did not keep an unprocessed message for us,
|
# If the previous run did not keep an unprocessed message for us,
|
||||||
# wait for a new one.
|
# wait for a new one.
|
||||||
try:
|
try:
|
||||||
message: Message = await take
|
message: Message = await self.incoming_message
|
||||||
except protocol_common.WireError as e:
|
except protocol_common.WireError as e:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.exception(__name__, e)
|
log.exception(__name__, e)
|
||||||
@ -152,7 +151,7 @@ class PairingContext(Context):
|
|||||||
exp_type,
|
exp_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
message: Message = await self.incoming_message.take()
|
message: Message = await self.incoming_message
|
||||||
|
|
||||||
if message.type not in expected_types:
|
if message.type not in expected_types:
|
||||||
raise UnexpectedMessageException(message)
|
raise UnexpectedMessageException(message)
|
||||||
|
@ -363,7 +363,7 @@ async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -
|
|||||||
s = ctx.sessions[session_id]
|
s = ctx.sessions[session_id]
|
||||||
update_session_last_used(s.channel_id, s.session_id)
|
update_session_last_used(s.channel_id, s.session_id)
|
||||||
|
|
||||||
s.incoming_message.publish(
|
s.incoming_message.put(
|
||||||
Message(
|
Message(
|
||||||
message_type,
|
message_type,
|
||||||
ctx.buffer[
|
ctx.buffer[
|
||||||
@ -389,7 +389,7 @@ async def _handle_pairing(ctx: Channel, message_length: int) -> None:
|
|||||||
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
|
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
ctx.connection_context.incoming_message.publish(
|
ctx.connection_context.incoming_message.put(
|
||||||
Message(
|
Message(
|
||||||
message_type,
|
message_type,
|
||||||
ctx.buffer[
|
ctx.buffer[
|
||||||
|
@ -32,19 +32,18 @@ class GenericSessionContext(Context):
|
|||||||
super().__init__(channel.iface, channel.channel_id)
|
super().__init__(channel.iface, channel.channel_id)
|
||||||
self.channel: Channel = channel
|
self.channel: Channel = channel
|
||||||
self.session_id: int = session_id
|
self.session_id: int = session_id
|
||||||
self.incoming_message = loop.chan()
|
self.incoming_message = loop.mailbox()
|
||||||
self.handler_finder: HandlerFinder = find_handler
|
self.handler_finder: HandlerFinder = find_handler
|
||||||
|
|
||||||
async def handle(self) -> None:
|
async def handle(self) -> None:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
self._handle_debug()
|
self._handle_debug()
|
||||||
|
|
||||||
take = self.incoming_message.take()
|
|
||||||
next_message: Message | None = None
|
next_message: Message | None = None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if await self._handle_message(take, next_message):
|
if await self._handle_message(next_message):
|
||||||
loop.schedule(self.handle())
|
loop.schedule(self.handle())
|
||||||
return
|
return
|
||||||
except UnexpectedMessageException as unexpected:
|
except UnexpectedMessageException as unexpected:
|
||||||
@ -71,12 +70,11 @@ class GenericSessionContext(Context):
|
|||||||
|
|
||||||
async def _handle_message(
|
async def _handle_message(
|
||||||
self,
|
self,
|
||||||
take: Awaitable[Any],
|
|
||||||
next_message: Message | None,
|
next_message: Message | None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message = await self._get_message(take, next_message)
|
message = await self._get_message(self.incoming_message, next_message)
|
||||||
except protocol_common.WireError as e:
|
except protocol_common.WireError as e:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.exception(__name__, e)
|
log.exception(__name__, e)
|
||||||
@ -136,7 +134,7 @@ class GenericSessionContext(Context):
|
|||||||
str(expected_types),
|
str(expected_types),
|
||||||
exp_type,
|
exp_type,
|
||||||
)
|
)
|
||||||
message: Message = await self.incoming_message.take()
|
message: Message = await self.incoming_message
|
||||||
if message.type not in expected_types:
|
if message.type not in expected_types:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
log.debug(
|
log.debug(
|
||||||
|
Loading…
Reference in New Issue
Block a user