1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-08 05:32:39 +00:00

refactor(core): replace loop.chan with loop.mailbox

[no changelog]
This commit is contained in:
M1nd3r 2024-10-18 11:41:02 +02:00
parent 22be48cca4
commit c241adfc4d
3 changed files with 9 additions and 12 deletions

View File

@ -76,7 +76,7 @@ class PairingContext(Context):
def __init__(self, channel_ctx: Channel) -> None: def __init__(self, channel_ctx: Channel) -> None:
super().__init__(channel_ctx.iface, channel_ctx.channel_id) super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel_ctx: Channel = channel_ctx self.channel_ctx: Channel = channel_ctx
self.incoming_message = loop.chan() self.incoming_message = loop.mailbox()
self.secret: bytes = random.bytes(16) self.secret: bytes = random.bytes(16)
self.display_data: PairingDisplayData = PairingDisplayData() self.display_data: PairingDisplayData = PairingDisplayData()
@ -91,7 +91,6 @@ class PairingContext(Context):
# apps.debug.DEBUG_CONTEXT = self # apps.debug.DEBUG_CONTEXT = self
take = self.incoming_message.take()
next_message: Message | None = None next_message: Message | None = None
while True: while True:
@ -100,7 +99,7 @@ class PairingContext(Context):
# If the previous run did not keep an unprocessed message for us, # If the previous run did not keep an unprocessed message for us,
# wait for a new one. # wait for a new one.
try: try:
message: Message = await take message: Message = await self.incoming_message
except protocol_common.WireError as e: except protocol_common.WireError as e:
if __debug__: if __debug__:
log.exception(__name__, e) log.exception(__name__, e)
@ -152,7 +151,7 @@ class PairingContext(Context):
exp_type, exp_type,
) )
message: Message = await self.incoming_message.take() message: Message = await self.incoming_message
if message.type not in expected_types: if message.type not in expected_types:
raise UnexpectedMessageException(message) raise UnexpectedMessageException(message)

View File

@ -363,7 +363,7 @@ async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -
s = ctx.sessions[session_id] s = ctx.sessions[session_id]
update_session_last_used(s.channel_id, s.session_id) update_session_last_used(s.channel_id, s.session_id)
s.incoming_message.publish( s.incoming_message.put(
Message( Message(
message_type, message_type,
ctx.buffer[ ctx.buffer[
@ -389,7 +389,7 @@ async def _handle_pairing(ctx: Channel, message_length: int) -> None:
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :] ">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
)[0] )[0]
ctx.connection_context.incoming_message.publish( ctx.connection_context.incoming_message.put(
Message( Message(
message_type, message_type,
ctx.buffer[ ctx.buffer[

View File

@ -32,19 +32,18 @@ class GenericSessionContext(Context):
super().__init__(channel.iface, channel.channel_id) super().__init__(channel.iface, channel.channel_id)
self.channel: Channel = channel self.channel: Channel = channel
self.session_id: int = session_id self.session_id: int = session_id
self.incoming_message = loop.chan() self.incoming_message = loop.mailbox()
self.handler_finder: HandlerFinder = find_handler self.handler_finder: HandlerFinder = find_handler
async def handle(self) -> None: async def handle(self) -> None:
if __debug__: if __debug__:
self._handle_debug() self._handle_debug()
take = self.incoming_message.take()
next_message: Message | None = None next_message: Message | None = None
while True: while True:
try: try:
if await self._handle_message(take, next_message): if await self._handle_message(next_message):
loop.schedule(self.handle()) loop.schedule(self.handle())
return return
except UnexpectedMessageException as unexpected: except UnexpectedMessageException as unexpected:
@ -71,12 +70,11 @@ class GenericSessionContext(Context):
async def _handle_message( async def _handle_message(
self, self,
take: Awaitable[Any],
next_message: Message | None, next_message: Message | None,
) -> bool: ) -> bool:
try: try:
message = await self._get_message(take, next_message) message = await self._get_message(self.incoming_message, next_message)
except protocol_common.WireError as e: except protocol_common.WireError as e:
if __debug__: if __debug__:
log.exception(__name__, e) log.exception(__name__, e)
@ -136,7 +134,7 @@ class GenericSessionContext(Context):
str(expected_types), str(expected_types),
exp_type, exp_type,
) )
message: Message = await self.incoming_message.take() message: Message = await self.incoming_message
if message.type not in expected_types: if message.type not in expected_types:
if __debug__: if __debug__:
log.debug( log.debug(