diff --git a/common/protob/messages-thp.proto b/common/protob/messages-thp.proto index 5d44f6a59..08361214a 100644 --- a/common/protob/messages-thp.proto +++ b/common/protob/messages-thp.proto @@ -62,7 +62,7 @@ message ThpNewSession{ * @next ThpNfcUnidirectionalTag // Sent by the Host */ message ThpStartPairingRequest{ - optional bytes host_name = 1; // Human-readable host name + optional string host_name = 1; // Human-readable host name } /** diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index 477dcd7ee..52b9ced5c 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -1,33 +1,31 @@ -from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] - +from trezor import log +from trezor.enums import ThpPairingMethod +from trezor.messages import ( + ThpCodeEntryChallenge, + ThpCodeEntryCommitment, + ThpCodeEntryCpaceHost, + ThpCodeEntryCpaceTrezor, + ThpCodeEntrySecret, + ThpCodeEntryTag, + ThpNfcUnideirectionalSecret, + ThpNfcUnidirectionalTag, + ThpQrCodeSecret, + ThpQrCodeTag, + ThpStartPairingRequest, +) from trezor.wire.errors import UnexpectedMessage from trezor.wire.thp import ChannelState from trezor.wire.thp.channel import Channel from trezor.wire.thp.thp_session import ThpError -if TYPE_CHECKING: - from trezor.enums import ThpPairingMethod - from trezor.messages import ( - ThpCodeEntryChallenge, - ThpCodeEntryCommitment, - ThpCodeEntryCpaceHost, - ThpCodeEntryCpaceTrezor, - ThpCodeEntrySecret, - ThpCodeEntryTag, - ThpNfcUnideirectionalSecret, - ThpNfcUnidirectionalTag, - ThpQrCodeSecret, - ThpQrCodeTag, - ThpStartPairingRequest, - ) - - # TODO implement the following handlers async def handle_pairing_request( channel: Channel, message: ThpStartPairingRequest ) -> ThpCodeEntryCommitment | None: + if __debug__: + log.debug(__name__, "handle_pairing_request") _check_state(channel, ChannelState.TP1) if _is_method_included(channel, ThpPairingMethod.PairingMethod_CodeEntry): channel.set_channel_state(ChannelState.TP2) diff --git a/core/src/trezor/messages.py b/core/src/trezor/messages.py index 900cf565a..b1efad505 100644 --- a/core/src/trezor/messages.py +++ b/core/src/trezor/messages.py @@ -6167,12 +6167,12 @@ if TYPE_CHECKING: return isinstance(msg, cls) class ThpStartPairingRequest(protobuf.MessageType): - host_name: "bytes | None" + host_name: "str | None" def __init__( self, *, - host_name: "bytes | None" = None, + host_name: "str | None" = None, ) -> None: pass diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 018f994e6..3bbdb8c8e 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -7,7 +7,7 @@ from storage import cache_thp from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache from trezor import log, loop, protobuf, utils, workflow from trezor.enums import FailureType, MessageType # , ThpPairingMethod -from trezor.messages import Failure +from trezor.messages import Failure, ThpDeviceProperties from trezor.wire import message_handler from trezor.wire.thp import ack_handler, thp_messages from trezor.wire.thp.handler_provider import get_handler @@ -61,9 +61,7 @@ class Channel(Context): self.is_cont_packet_expected: bool = False self.expected_payload_length: int = 0 self.bytes_read: int = 0 - self.selected_pairing_methods = ( - [] - ) # TODO better # ThpPairingMethod.PairingMethod_NoMethod + self.selected_pairing_methods = [] from trezor.wire.thp.session_context import load_cached_sessions self.connection_context = None @@ -300,6 +298,22 @@ class Channel(Context): + TAG_LENGTH : message_length - CHECKSUM_LENGTH ] + + device_properties = thp_messages.decode_message( + self.buffer[ + INIT_DATA_OFFSET + + KEY_LENGTH + + TAG_LENGTH : message_length + - CHECKSUM_LENGTH + - TAG_LENGTH + ], + 0, + "ThpDeviceProperties", + ) + if TYPE_CHECKING: + assert isinstance(device_properties, ThpDeviceProperties) + for i in device_properties.pairing_methods: + self.selected_pairing_methods.append(i) if __debug__: log.debug( __name__, @@ -308,13 +322,20 @@ class Channel(Context): utils.get_bytes_as_str(handshake_completion_request_noise_payload), ) + paired: bool = False # TODO should be output from credential check + # send hanshake completion response loop.schedule( self._write_encrypted_payload_loop( - HANDSHAKE_COMP_RES, thp_messages.get_handshake_completion_response() + HANDSHAKE_COMP_RES, + thp_messages.get_handshake_completion_response(paired=paired), ) ) - self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + # TODO add credential recognition + if paired: + self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + else: + self.set_channel_state(ChannelState.TP1) async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None: if __debug__: @@ -354,7 +375,11 @@ class Channel(Context): if self.connection_context is None: self.connection_context = PairingContext(self) + loop.schedule(self.connection_context.handle()) + print("TEST selected methods") + for i in self.selected_pairing_methods: + print("method:", i) self._decrypt_buffer(message_length) message_type = ustruct.unpack( diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 85a63938d..a3c1c75f7 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -13,7 +13,7 @@ from apps.thp.pairing import handle_pairing_request from .channel import Channel if TYPE_CHECKING: - from typing import Container # pyright:ignore[reportShadowedImports] + from typing import Container, Generator # pyright:ignore[reportShadowedImports] pass @@ -34,9 +34,6 @@ class PairingContext: take = self.incoming_message.take() next_message: MessageWithType | None = None - # Take a mark of modules that are imported at this point, so we can - # roll back and un-import any others. - # TODO modules = utils.unimport_begin() while True: try: if next_message is None: @@ -74,8 +71,6 @@ class PairingContext: if next_message is None: # Shut down the loop if there is no next message waiting. - # Let the session be restarted from `main`. - loop.clear() return # pylint: disable=lost-exception except Exception as exc: @@ -146,7 +141,6 @@ async def handle_pairing_message( if TYPE_CHECKING: assert isinstance(req_msg, ThpStartPairingRequest) # TODO remove task = handler(ctx.channel, req_msg) - # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a # response message, or raise an exception (a rather common @@ -154,7 +148,7 @@ async def handle_pairing_message( if use_workflow: # Spawn a workflow around the task. This ensures that concurrent # workflows are shut down. - # res_msg = await workflow.spawn(context.with_context(ctx, task)) + res_msg = await workflow.spawn(with_context(ctx, task)) pass # TODO else: # For debug messages, ignore workflow processing and just await @@ -187,7 +181,6 @@ async def handle_pairing_message( else: log.exception(__name__, exc) res_msg = message_handler.failure(exc) - if res_msg is not None: # perform the write outside the big try-except block, so that usb write # problem bubbles up @@ -197,3 +190,34 @@ async def handle_pairing_message( def get_handler(messageType: int): return handle_pairing_request + + +def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator: + """Run a workflow in a particular context. + + Stores the context in a closure and installs it into the global variable every time + the closure is resumed, thus making sure that all calls to `wire.context.*` will + work as expected. + """ + global CURRENT_CONTEXT + send_val = None + send_exc = None + + while True: + CURRENT_CONTEXT = ctx + try: + if send_exc is not None: + res = workflow.throw(send_exc) + else: + res = workflow.send(send_val) + except StopIteration as st: + return st.value + finally: + CURRENT_CONTEXT = None + + try: + send_val = yield res + except BaseException as e: + send_exc = e + else: + send_exc = None diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index f1dadb8a4..886d196a4 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -1,7 +1,7 @@ import ustruct # pyright:ignore[reportMissingModuleSource] from storage.cache_thp import BROADCAST_CHANNEL_ID -from trezor import protobuf +from trezor import log, protobuf from .. import message_handler from ..protocol_common import Message @@ -90,16 +90,27 @@ def get_handshake_init_response() -> bytes: return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x40\x41\x42\x43\x44\x45\x46\x47\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15" -def get_handshake_completion_response() -> bytes: +def get_handshake_completion_response(paired: bool) -> bytes: + if paired: + return ( + TREZOR_STATE_PAIRED + + b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15" + ) return ( - TREZOR_STATE_PAIRED + TREZOR_STATE_UNPAIRED + b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15" ) -def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType: - print("decode message") - expected_type = protobuf.type_for_wire(msg_type) +def decode_message( + buffer: bytes, msg_type: int, message_name: str | None = None +) -> protobuf.MessageType: + if __debug__: + log.debug(__name__, "decode message") + if message_name is not None: + expected_type = protobuf.type_for_name(message_name) + else: + expected_type = protobuf.type_for_wire(msg_type) x = message_handler.wrap_protobuf_load(buffer, expected_type) print("result decoded", x) return x diff --git a/python/src/trezorlib/messages.py b/python/src/trezorlib/messages.py index c78369046..07d80e02e 100644 --- a/python/src/trezorlib/messages.py +++ b/python/src/trezorlib/messages.py @@ -7804,13 +7804,13 @@ class ThpNewSession(protobuf.MessageType): class ThpStartPairingRequest(protobuf.MessageType): MESSAGE_WIRE_TYPE = 1008 FIELDS = { - 1: protobuf.Field("host_name", "bytes", repeated=False, required=False, default=None), + 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), } def __init__( self, *, - host_name: Optional["bytes"] = None, + host_name: Optional["str"] = None, ) -> None: self.host_name = host_name diff --git a/rust/trezor-client/src/protos/generated/messages_thp.rs b/rust/trezor-client/src/protos/generated/messages_thp.rs index b4319b4b3..ef9301b36 100644 --- a/rust/trezor-client/src/protos/generated/messages_thp.rs +++ b/rust/trezor-client/src/protos/generated/messages_thp.rs @@ -835,7 +835,7 @@ impl ::protobuf::reflect::ProtobufValue for ThpNewSession { pub struct ThpStartPairingRequest { // message fields // @@protoc_insertion_point(field:hw.trezor.messages.thp.ThpStartPairingRequest.host_name) - pub host_name: ::std::option::Option<::std::vec::Vec>, + pub host_name: ::std::option::Option<::std::string::String>, // special fields // @@protoc_insertion_point(special_field:hw.trezor.messages.thp.ThpStartPairingRequest.special_fields) pub special_fields: ::protobuf::SpecialFields, @@ -852,12 +852,12 @@ impl ThpStartPairingRequest { ::std::default::Default::default() } - // optional bytes host_name = 1; + // optional string host_name = 1; - pub fn host_name(&self) -> &[u8] { + pub fn host_name(&self) -> &str { match self.host_name.as_ref() { Some(v) => v, - None => &[], + None => "", } } @@ -870,22 +870,22 @@ impl ThpStartPairingRequest { } // Param is passed by value, moved - pub fn set_host_name(&mut self, v: ::std::vec::Vec) { + pub fn set_host_name(&mut self, v: ::std::string::String) { self.host_name = ::std::option::Option::Some(v); } // Mutable pointer to the field. // If field is not initialized, it is initialized with default value first. - pub fn mut_host_name(&mut self) -> &mut ::std::vec::Vec { + pub fn mut_host_name(&mut self) -> &mut ::std::string::String { if self.host_name.is_none() { - self.host_name = ::std::option::Option::Some(::std::vec::Vec::new()); + self.host_name = ::std::option::Option::Some(::std::string::String::new()); } self.host_name.as_mut().unwrap() } // Take field - pub fn take_host_name(&mut self) -> ::std::vec::Vec { - self.host_name.take().unwrap_or_else(|| ::std::vec::Vec::new()) + pub fn take_host_name(&mut self) -> ::std::string::String { + self.host_name.take().unwrap_or_else(|| ::std::string::String::new()) } fn generated_message_descriptor_data() -> ::protobuf::reflect::GeneratedMessageDescriptorData { @@ -915,7 +915,7 @@ impl ::protobuf::Message for ThpStartPairingRequest { while let Some(tag) = is.read_raw_tag_or_eof()? { match tag { 10 => { - self.host_name = ::std::option::Option::Some(is.read_bytes()?); + self.host_name = ::std::option::Option::Some(is.read_string()?); }, tag => { ::protobuf::rt::read_unknown_or_skip_group(tag, is, self.special_fields.mut_unknown_fields())?; @@ -930,7 +930,7 @@ impl ::protobuf::Message for ThpStartPairingRequest { fn compute_size(&self) -> u64 { let mut my_size = 0; if let Some(v) = self.host_name.as_ref() { - my_size += ::protobuf::rt::bytes_size(1, &v); + my_size += ::protobuf::rt::string_size(1, &v); } my_size += ::protobuf::rt::unknown_fields_size(self.special_fields.unknown_fields()); self.special_fields.cached_size().set(my_size as u32); @@ -939,7 +939,7 @@ impl ::protobuf::Message for ThpStartPairingRequest { fn write_to_with_cached_sizes(&self, os: &mut ::protobuf::CodedOutputStream<'_>) -> ::protobuf::Result<()> { if let Some(v) = self.host_name.as_ref() { - os.write_bytes(1, v)?; + os.write_string(1, v)?; } os.write_unknown_fields(self.special_fields.unknown_fields())?; ::std::result::Result::Ok(()) @@ -3236,9 +3236,9 @@ static file_descriptor_proto_data: &'static [u8] = b"\ \x12\x1e\n\npassphrase\x18\x01\x20\x01(\tR\npassphrase\x12\x1b\n\ton_dev\ ice\x18\x02\x20\x01(\x08R\x08onDevice\"5\n\rThpNewSession\x12$\n\x0enew_\ session_id\x18\x01\x20\x01(\rR\x0cnewSessionId\"5\n\x16ThpStartPairingRe\ - quest\x12\x1b\n\thost_name\x18\x01\x20\x01(\x0cR\x08hostName\"8\n\x16Thp\ - CodeEntryCommitment\x12\x1e\n\ncommitment\x18\x01\x20\x01(\x0cR\ncommitm\ - ent\"5\n\x15ThpCodeEntryChallenge\x12\x1c\n\tchallenge\x18\x01\x20\x01(\ + quest\x12\x1b\n\thost_name\x18\x01\x20\x01(\tR\x08hostName\"8\n\x16ThpCo\ + deEntryCommitment\x12\x1e\n\ncommitment\x18\x01\x20\x01(\x0cR\ncommitmen\ + t\"5\n\x15ThpCodeEntryChallenge\x12\x1c\n\tchallenge\x18\x01\x20\x01(\ \x0cR\tchallenge\"J\n\x15ThpCodeEntryCpaceHost\x121\n\x15cpace_host_publ\ ic_key\x18\x01\x20\x01(\x0cR\x12cpaceHostPublicKey\"P\n\x17ThpCodeEntryC\ paceTrezor\x125\n\x17cpace_trezor_public_key\x18\x01\x20\x01(\x0cR\x14cp\