diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index 52b9ced5c..b3ac3f263 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -1,4 +1,4 @@ -from trezor import log +from trezor import log, protobuf from trezor.enums import ThpPairingMethod from trezor.messages import ( ThpCodeEntryChallenge, @@ -22,8 +22,9 @@ from trezor.wire.thp.thp_session import ThpError async def handle_pairing_request( - channel: Channel, message: ThpStartPairingRequest + channel: Channel, message: protobuf.MessageType ) -> ThpCodeEntryCommitment | None: + assert ThpStartPairingRequest.is_type_of(message) if __debug__: log.debug(__name__, "handle_pairing_request") _check_state(channel, ChannelState.TP1) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 65571c878..35bc1715d 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -7,7 +7,11 @@ from storage import cache_thp from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache from trezor import log, loop, protobuf, utils, workflow from trezor.enums import FailureType, MessageType -from trezor.messages import Failure, ThpCreateNewSession, ThpDeviceProperties +from trezor.messages import ( + Failure, + ThpCreateNewSession, + ThpHandshakeCompletionReqNoisePayload, +) from trezor.wire import message_handler from trezor.wire.thp import ack_handler, thp_messages from trezor.wire.thp.handler_provider import get_handler @@ -299,7 +303,7 @@ class Channel(Context): - CHECKSUM_LENGTH ] - device_properties = thp_messages.decode_message( + noise_payload = thp_messages.decode_message( self.buffer[ INIT_DATA_OFFSET + KEY_LENGTH @@ -308,11 +312,11 @@ class Channel(Context): - TAG_LENGTH ], 0, - "ThpDeviceProperties", + "ThpHandshakeCompletionReqNoisePayload", ) if TYPE_CHECKING: - assert isinstance(device_properties, ThpDeviceProperties) - for i in device_properties.pairing_methods: + assert isinstance(noise_payload, ThpHandshakeCompletionReqNoisePayload) + for i in noise_payload.pairing_methods: self.selected_pairing_methods.append(i) if __debug__: log.debug( @@ -322,7 +326,7 @@ class Channel(Context): utils.get_bytes_as_str(handshake_completion_request_noise_payload), ) - paired: bool = True # TODO should be output from credential check + paired: bool = False # TODO should be output from credential check # send hanshake completion response loop.schedule( diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index a3c1c75f7..5f4a5fa47 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -1,12 +1,13 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from trezor import log, loop, protobuf, workflow -from trezor.messages import ThpStartPairingRequest +from trezor.enums import MessageType from trezor.wire import message_handler, protocol_common from trezor.wire.context import UnexpectedMessageWithId from trezor.wire.errors import ActionCancelled from trezor.wire.protocol_common import MessageWithType from trezor.wire.thp.session_context import UnexpectedMessageWithType +from trezor.wire.thp.thp_session import ThpError from apps.thp.pairing import handle_pairing_request @@ -138,9 +139,8 @@ async def handle_pairing_message( req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) # Create the handler task. - if TYPE_CHECKING: - assert isinstance(req_msg, ThpStartPairingRequest) # TODO remove task = handler(ctx.channel, req_msg) + # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a # response message, or raise an exception (a rather common @@ -189,7 +189,10 @@ async def handle_pairing_message( def get_handler(messageType: int): - return handle_pairing_request + if messageType == MessageType.ThpStartPairingRequest: + return handle_pairing_request + else: + raise ThpError("Handler for this method is not implemented yet") def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator: