From 00d7e4293953c38a94454dfd7a33c644a9a9e294 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Thu, 17 Oct 2024 11:43:51 +0200 Subject: [PATCH] fix(core): fix pairing context after separation of THP messages [no changelog] --- core/src/apps/base.py | 1 + core/src/trezor/wire/message_handler.py | 6 ++++++ core/src/trezor/wire/thp/memory_manager.py | 9 +-------- core/src/trezor/wire/thp/pairing_context.py | 14 +++++++++++--- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 0dfc65bf4c..30ba67d0e3 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -277,6 +277,7 @@ async def handle_EndSession(msg: EndSession) -> Success: end_current_session() return Success() + async def handle_Ping(msg: Ping) -> Success: if msg.button_protection: from trezor.enums import ButtonRequestType as B diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py index b2459cb3eb..33c08afc75 100644 --- a/core/src/trezor/wire/message_handler.py +++ b/core/src/trezor/wire/message_handler.py @@ -69,6 +69,12 @@ if utils.USE_THP: return name return None + def get_msg_type(msg_name: str) -> int | None: + value = getattr(ThpMessageType, msg_name) + if isinstance(value, int): + return value + return None + async def handle_single_message( ctx: Context, diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py index 34c7f10e18..d9348a6137 100644 --- a/core/src/trezor/wire/thp/memory_manager.py +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -1,6 +1,6 @@ from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH from trezor import log, protobuf, utils -from trezor.enums import ThpMessageType +from trezor.wire.message_handler import get_msg_type from . import ChannelState, ThpError from .checksum import CHECKSUM_LENGTH @@ -53,13 +53,6 @@ def get_write_buffer( return buffer -def get_msg_type(msg_name: str) -> int | None: - value = getattr(ThpMessageType, msg_name) - if isinstance(value, int): - return value - return None - - def encode_into_buffer( buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int ) -> int: diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 43613126f1..9892fd4dcc 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -158,7 +158,11 @@ class PairingContext(Context): raise UnexpectedMessageException(message) if expected_type is None: - expected_type = protobuf.type_for_wire(message.type) + name = message_handler.get_msg_name(message.type) + if name is None: + expected_type = protobuf.type_for_wire(message.type) + else: + expected_type = protobuf.type_for_name(name) return message_handler.wrap_protobuf_load(message.data, expected_type) @@ -168,12 +172,16 @@ class PairingContext(Context): async def call( self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType] ) -> protobuf.MessageType: - assert expected_type.MESSAGE_WIRE_TYPE is not None + expected_wire_type = message_handler.get_msg_type(expected_type.MESSAGE_NAME) + if expected_wire_type is None: + expected_wire_type = expected_type.MESSAGE_WIRE_TYPE + + assert expected_wire_type is not None await self.write(msg) del msg - return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + return await self.read((expected_wire_type,), expected_type) async def call_any( self, msg: protobuf.MessageType, *expected_types: int