1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-01 10:20:59 +00:00

refactor(core): replace loop.chan with loop.mailbox

[no changelog]
This commit is contained in:
M1nd3r 2024-10-18 11:41:02 +02:00
parent 22be48cca4
commit c241adfc4d
3 changed files with 9 additions and 12 deletions

View File

@ -76,7 +76,7 @@ class PairingContext(Context):
def __init__(self, channel_ctx: Channel) -> None:
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel_ctx: Channel = channel_ctx
self.incoming_message = loop.chan()
self.incoming_message = loop.mailbox()
self.secret: bytes = random.bytes(16)
self.display_data: PairingDisplayData = PairingDisplayData()
@ -91,7 +91,6 @@ class PairingContext(Context):
# apps.debug.DEBUG_CONTEXT = self
take = self.incoming_message.take()
next_message: Message | None = None
while True:
@ -100,7 +99,7 @@ class PairingContext(Context):
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
try:
message: Message = await take
message: Message = await self.incoming_message
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
@ -152,7 +151,7 @@ class PairingContext(Context):
exp_type,
)
message: Message = await self.incoming_message.take()
message: Message = await self.incoming_message
if message.type not in expected_types:
raise UnexpectedMessageException(message)

View File

@ -363,7 +363,7 @@ async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -
s = ctx.sessions[session_id]
update_session_last_used(s.channel_id, s.session_id)
s.incoming_message.publish(
s.incoming_message.put(
Message(
message_type,
ctx.buffer[
@ -389,7 +389,7 @@ async def _handle_pairing(ctx: Channel, message_length: int) -> None:
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
)[0]
ctx.connection_context.incoming_message.publish(
ctx.connection_context.incoming_message.put(
Message(
message_type,
ctx.buffer[

View File

@ -32,19 +32,18 @@ class GenericSessionContext(Context):
super().__init__(channel.iface, channel.channel_id)
self.channel: Channel = channel
self.session_id: int = session_id
self.incoming_message = loop.chan()
self.incoming_message = loop.mailbox()
self.handler_finder: HandlerFinder = find_handler
async def handle(self) -> None:
if __debug__:
self._handle_debug()
take = self.incoming_message.take()
next_message: Message | None = None
while True:
try:
if await self._handle_message(take, next_message):
if await self._handle_message(next_message):
loop.schedule(self.handle())
return
except UnexpectedMessageException as unexpected:
@ -71,12 +70,11 @@ class GenericSessionContext(Context):
async def _handle_message(
self,
take: Awaitable[Any],
next_message: Message | None,
) -> bool:
try:
message = await self._get_message(take, next_message)
message = await self._get_message(self.incoming_message, next_message)
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
@ -136,7 +134,7 @@ class GenericSessionContext(Context):
str(expected_types),
exp_type,
)
message: Message = await self.incoming_message.take()
message: Message = await self.incoming_message
if message.type not in expected_types:
if __debug__:
log.debug(