mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-01 10:20:59 +00:00
fix(core): fix pairing context after separation of THP messages
[no changelog]
This commit is contained in:
parent
24d43ab2e5
commit
00d7e42939
@ -277,6 +277,7 @@ async def handle_EndSession(msg: EndSession) -> Success:
|
||||
end_current_session()
|
||||
return Success()
|
||||
|
||||
|
||||
async def handle_Ping(msg: Ping) -> Success:
|
||||
if msg.button_protection:
|
||||
from trezor.enums import ButtonRequestType as B
|
||||
|
@ -69,6 +69,12 @@ if utils.USE_THP:
|
||||
return name
|
||||
return None
|
||||
|
||||
def get_msg_type(msg_name: str) -> int | None:
|
||||
value = getattr(ThpMessageType, msg_name)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
async def handle_single_message(
|
||||
ctx: Context,
|
||||
|
@ -1,6 +1,6 @@
|
||||
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
|
||||
from trezor import log, protobuf, utils
|
||||
from trezor.enums import ThpMessageType
|
||||
from trezor.wire.message_handler import get_msg_type
|
||||
|
||||
from . import ChannelState, ThpError
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
@ -53,13 +53,6 @@ def get_write_buffer(
|
||||
return buffer
|
||||
|
||||
|
||||
def get_msg_type(msg_name: str) -> int | None:
|
||||
value = getattr(ThpMessageType, msg_name)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def encode_into_buffer(
|
||||
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
|
||||
) -> int:
|
||||
|
@ -158,7 +158,11 @@ class PairingContext(Context):
|
||||
raise UnexpectedMessageException(message)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(message.type)
|
||||
name = message_handler.get_msg_name(message.type)
|
||||
if name is None:
|
||||
expected_type = protobuf.type_for_wire(message.type)
|
||||
else:
|
||||
expected_type = protobuf.type_for_name(name)
|
||||
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
@ -168,12 +172,16 @@ class PairingContext(Context):
|
||||
async def call(
|
||||
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
|
||||
) -> protobuf.MessageType:
|
||||
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
||||
expected_wire_type = message_handler.get_msg_type(expected_type.MESSAGE_NAME)
|
||||
if expected_wire_type is None:
|
||||
expected_wire_type = expected_type.MESSAGE_WIRE_TYPE
|
||||
|
||||
assert expected_wire_type is not None
|
||||
|
||||
await self.write(msg)
|
||||
del msg
|
||||
|
||||
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
||||
return await self.read((expected_wire_type,), expected_type)
|
||||
|
||||
async def call_any(
|
||||
self, msg: protobuf.MessageType, *expected_types: int
|
||||
|
Loading…
Reference in New Issue
Block a user