mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-01 10:20:59 +00:00
fix(core): fix pairing context after separation of THP messages
[no changelog]
This commit is contained in:
parent
24d43ab2e5
commit
00d7e42939
@ -277,6 +277,7 @@ async def handle_EndSession(msg: EndSession) -> Success:
|
|||||||
end_current_session()
|
end_current_session()
|
||||||
return Success()
|
return Success()
|
||||||
|
|
||||||
|
|
||||||
async def handle_Ping(msg: Ping) -> Success:
|
async def handle_Ping(msg: Ping) -> Success:
|
||||||
if msg.button_protection:
|
if msg.button_protection:
|
||||||
from trezor.enums import ButtonRequestType as B
|
from trezor.enums import ButtonRequestType as B
|
||||||
|
@ -69,6 +69,12 @@ if utils.USE_THP:
|
|||||||
return name
|
return name
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_msg_type(msg_name: str) -> int | None:
|
||||||
|
value = getattr(ThpMessageType, msg_name)
|
||||||
|
if isinstance(value, int):
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def handle_single_message(
|
async def handle_single_message(
|
||||||
ctx: Context,
|
ctx: Context,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
|
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
|
||||||
from trezor import log, protobuf, utils
|
from trezor import log, protobuf, utils
|
||||||
from trezor.enums import ThpMessageType
|
from trezor.wire.message_handler import get_msg_type
|
||||||
|
|
||||||
from . import ChannelState, ThpError
|
from . import ChannelState, ThpError
|
||||||
from .checksum import CHECKSUM_LENGTH
|
from .checksum import CHECKSUM_LENGTH
|
||||||
@ -53,13 +53,6 @@ def get_write_buffer(
|
|||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
def get_msg_type(msg_name: str) -> int | None:
|
|
||||||
value = getattr(ThpMessageType, msg_name)
|
|
||||||
if isinstance(value, int):
|
|
||||||
return value
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def encode_into_buffer(
|
def encode_into_buffer(
|
||||||
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
|
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
|
||||||
) -> int:
|
) -> int:
|
||||||
|
@ -158,7 +158,11 @@ class PairingContext(Context):
|
|||||||
raise UnexpectedMessageException(message)
|
raise UnexpectedMessageException(message)
|
||||||
|
|
||||||
if expected_type is None:
|
if expected_type is None:
|
||||||
expected_type = protobuf.type_for_wire(message.type)
|
name = message_handler.get_msg_name(message.type)
|
||||||
|
if name is None:
|
||||||
|
expected_type = protobuf.type_for_wire(message.type)
|
||||||
|
else:
|
||||||
|
expected_type = protobuf.type_for_name(name)
|
||||||
|
|
||||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||||
|
|
||||||
@ -168,12 +172,16 @@ class PairingContext(Context):
|
|||||||
async def call(
|
async def call(
|
||||||
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
|
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
|
||||||
) -> protobuf.MessageType:
|
) -> protobuf.MessageType:
|
||||||
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
expected_wire_type = message_handler.get_msg_type(expected_type.MESSAGE_NAME)
|
||||||
|
if expected_wire_type is None:
|
||||||
|
expected_wire_type = expected_type.MESSAGE_WIRE_TYPE
|
||||||
|
|
||||||
|
assert expected_wire_type is not None
|
||||||
|
|
||||||
await self.write(msg)
|
await self.write(msg)
|
||||||
del msg
|
del msg
|
||||||
|
|
||||||
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
return await self.read((expected_wire_type,), expected_type)
|
||||||
|
|
||||||
async def call_any(
|
async def call_any(
|
||||||
self, msg: protobuf.MessageType, *expected_types: int
|
self, msg: protobuf.MessageType, *expected_types: int
|
||||||
|
Loading…
Reference in New Issue
Block a user