1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-01 10:20:59 +00:00

fix(core): fix pairing context after separation of THP messages

[no changelog]
This commit is contained in:
M1nd3r 2024-10-17 11:43:51 +02:00
parent 24d43ab2e5
commit 00d7e42939
4 changed files with 19 additions and 11 deletions

View File

@ -277,6 +277,7 @@ async def handle_EndSession(msg: EndSession) -> Success:
end_current_session()
return Success()
async def handle_Ping(msg: Ping) -> Success:
if msg.button_protection:
from trezor.enums import ButtonRequestType as B

View File

@ -69,6 +69,12 @@ if utils.USE_THP:
return name
return None
def get_msg_type(msg_name: str) -> int | None:
value = getattr(ThpMessageType, msg_name)
if isinstance(value, int):
return value
return None
async def handle_single_message(
ctx: Context,

View File

@ -1,6 +1,6 @@
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, protobuf, utils
from trezor.enums import ThpMessageType
from trezor.wire.message_handler import get_msg_type
from . import ChannelState, ThpError
from .checksum import CHECKSUM_LENGTH
@ -53,13 +53,6 @@ def get_write_buffer(
return buffer
def get_msg_type(msg_name: str) -> int | None:
value = getattr(ThpMessageType, msg_name)
if isinstance(value, int):
return value
return None
def encode_into_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int:

View File

@ -158,7 +158,11 @@ class PairingContext(Context):
raise UnexpectedMessageException(message)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
name = message_handler.get_msg_name(message.type)
if name is None:
expected_type = protobuf.type_for_wire(message.type)
else:
expected_type = protobuf.type_for_name(name)
return message_handler.wrap_protobuf_load(message.data, expected_type)
@ -168,12 +172,16 @@ class PairingContext(Context):
async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
) -> protobuf.MessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
expected_wire_type = message_handler.get_msg_type(expected_type.MESSAGE_NAME)
if expected_wire_type is None:
expected_wire_type = expected_type.MESSAGE_WIRE_TYPE
assert expected_wire_type is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
return await self.read((expected_wire_type,), expected_type)
async def call_any(
self, msg: protobuf.MessageType, *expected_types: int