From a245ef195e707962708881f437d02b06cf6c3669 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 27 Mar 2024 17:22:07 +0100 Subject: [PATCH] feat(core): session creation, part 2 --- common/protob/messages-thp.proto | 38 +++++++++++++++++++++ common/protob/messages.proto | 11 ++++++ core/src/storage/cache_thp.py | 7 ++-- core/src/trezor/enums/__init__.py | 1 + core/src/trezor/messages.py | 16 +++++++++ core/src/trezor/wire/thp/channel_context.py | 37 ++++++++++++++------ core/src/trezor/wire/thp/session_context.py | 6 ++-- core/src/trezor/wire/thp/thp_messages.py | 12 ++++++- 8 files changed, 110 insertions(+), 18 deletions(-) create mode 100644 common/protob/messages-thp.proto diff --git a/common/protob/messages-thp.proto b/common/protob/messages-thp.proto new file mode 100644 index 000000000..41309101b --- /dev/null +++ b/common/protob/messages-thp.proto @@ -0,0 +1,38 @@ +syntax = "proto2"; +package hw.trezor.messages.thp; + +// Sugar for easier handling in Java +option java_package = "com.satoshilabs.trezor.lib.protobuf"; +option java_outer_classname = "TrezorMessageThp"; + + + +// Numeric identifiers of pairing methods. +enum PairingMethod { + PairingMethod_None = 1; // Trust without MITM protection. + PairingMethod_CodeEntry = 2; // User types code diplayed on Trezor into the host application. + PairingMethod_QrCode = 3; // User scans code displayed on Trezor into host application. + PairingMethod_NFC_Unidirectional = 4; // Trezor transmits an authentication key to the host device via NFC. +} + +message DeviceProperties { + optional string internal_model = 1; // Internal model name e.g. "T2B1". + optional uint32 model_variant = 2; // Encodes the device properties such as color. + optional bool bootloader_mode = 3; // Indicates whether the device is in bootloader or firmware mode. + optional uint32 protocol_version = 4; // The communication protocol version supported by the firmware. + repeated PairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor. +} + +message HandshakeCompletionReqNoisePayload { + optional bytes host_pairing_credential = 1; // Host's pairing credential + repeated PairingMethod pairing_methods = 2; // The pairing methods chosen by the host +} + +message CreateNewSession{ + optional string passphrase = 1; + optional bool on_device = 2; // user wants to enter passphrase on the device +} + +message NewSession{ + optional uint32 new_session_id = 1; +} diff --git a/common/protob/messages.proto b/common/protob/messages.proto index f0a5d0cf5..f274aebbf 100644 --- a/common/protob/messages.proto +++ b/common/protob/messages.proto @@ -375,4 +375,15 @@ enum MessageType { MessageType_SolanaAddress = 903 [(wire_out) = true]; MessageType_SolanaSignTx = 904 [(wire_in) = true]; MessageType_SolanaTxSignature = 905 [(wire_out) = true]; + + // THP + MessageType_StartPairingRequest = 1000 [(bitcoin_only) = true, (wire_in) = true]; + MessageType_StartPairingResponse = 1001 [(bitcoin_only) = true, (wire_out) = true]; + MessageType_CredentialRequest = 1002 [(bitcoin_only) = true, (wire_in) = true]; + MessageType_CredentialResponse = 1003 [(bitcoin_only) = true, (wire_out) = true]; + MessageType_EndRequest = 1004 [(bitcoin_only) = true, (wire_in) = true]; + MessageType_EndResponse = 1005 [(bitcoin_only) = true, (wire_out) = true]; + MessageType_CreateNewSession = 1006[(bitcoin_only)=true,(wire_in)=true]; + MessageType_NewSession = 1007[(bitcoin_only)=true,(wire_out)=true]; + } diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index f3a885414..9a9b8e220 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -47,7 +47,7 @@ class ChannelCache(ConnectionCache): self.state = bytearray(_CHANNEL_STATE_LENGTH) self.iface = bytearray(1) # TODO add decoding self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5) - self.session_id_counter = 0x01 + self.session_id_counter = 0x00 self.fields = () super().__init__() @@ -276,13 +276,14 @@ def get_next_channel_id() -> bytes: def get_next_session_id(channel: ChannelCache) -> bytes: - while not _is_session_id_unique(channel): + while True: if channel.session_id_counter >= 255: channel.session_id_counter = 1 else: channel.session_id_counter += 1 + if _is_session_id_unique(channel): + break new_sid = channel.session_id_counter - channel.session_id_counter += 1 return new_sid.to_bytes(_SESSION_ID_LENGTH, "big") diff --git a/core/src/trezor/enums/__init__.py b/core/src/trezor/enums/__init__.py index 3335f1b27..cadc7ae17 100644 --- a/core/src/trezor/enums/__init__.py +++ b/core/src/trezor/enums/__init__.py @@ -264,6 +264,7 @@ if TYPE_CHECKING: SolanaAddress = 903 SolanaSignTx = 904 SolanaTxSignature = 905 + CreateNewSession = 1006 class FailureType(IntEnum): UnexpectedMessage = 1 diff --git a/core/src/trezor/messages.py b/core/src/trezor/messages.py index a4d6b4ee2..d7dec6b05 100644 --- a/core/src/trezor/messages.py +++ b/core/src/trezor/messages.py @@ -372,6 +372,22 @@ if TYPE_CHECKING: @classmethod def is_type_of(cls, msg: Any) -> TypeGuard["PassphraseAck"]: return isinstance(msg, cls) + + class CreateNewSession(protobuf.MessageType): + passphrase: "str | None" + on_device: "bool | None" + + def __init__( + self, + *, + passphrase: "str | None" = None, + on_device: "bool | None" = None, + ) -> None: + pass + + @classmethod + def is_type_of(cls, msg: Any) -> TypeGuard["CreateNewSession"]: + return isinstance(msg, cls) class HDNodeType(protobuf.MessageType): depth: "int" diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel_context.py index 74e93e02c..547c1ca04 100644 --- a/core/src/trezor/wire/thp/channel_context.py +++ b/core/src/trezor/wire/thp/channel_context.py @@ -7,7 +7,8 @@ import usb from storage import cache_thp from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache from trezor import loop, protobuf, utils -from trezor.wire.thp import thp_messages +from trezor.messages import CreateNewSession +from trezor.wire import message_handler from ..protocol_common import Context from . import ChannelState, SessionState, checksum @@ -22,7 +23,7 @@ from .thp_messages import ( from .thp_session import ThpError if TYPE_CHECKING: - from trezorio import WireInterface # type:ignore + from trezorio import WireInterface # pyright:ignore[reportMissingImports] _WIRE_INTERFACE_USB = b"\x01" @@ -182,6 +183,7 @@ class ChannelContext(Context): # TODO ignore message self._todo_clear_buffer() return + if state is ChannelState.ENCRYPTED_TRANSPORT: self._decrypt_buffer() session_id, message_type = ustruct.unpack( @@ -189,13 +191,23 @@ class ChannelContext(Context): ) if session_id == 0: try: - message = thp_messages.decode_message( - self.buffer[INIT_DATA_OFFSET + 3 :], message_type - ) - print(message) - except Exception as e: - print(e) + buf = self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH] + expected_type = protobuf.type_for_wire(message_type) + message = message_handler.wrap_protobuf_load(buf, expected_type) + print(message) + # ------------------------------------------------TYPE ERROR------------------------------------------------ + session_message: CreateNewSession = message + print("passphrase:", session_message.passphrase) + # await thp_messages.handle_CreateNewSession(message) + if session_message.passphrase is not None: + self.create_new_session(session_message.passphrase) + else: + self.create_new_session() + except Exception as e: + print("Proč??") + print(e) + return # TODO not finished if session_id not in self.sessions: @@ -255,8 +267,13 @@ class ChannelContext(Context): self, passphrase="", ) -> None: # TODO change it to output session data - pass - # create a new session with this passphrase + print("create new session") + from trezor.wire.thp.session_context import SessionContext + + session = SessionContext.create_new_session(self) + print("help") + self.sessions[session.session_id] = session + print("new session created. Session id:", session.session_id) # OTHER diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 78e1afe40..a766e15cf 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -18,12 +18,10 @@ class SessionContext(Context): super().__init__(channel_context.iface, channel_context.channel_id) self.channel_context = channel_context self.session_cache = session_cache - self.session_id = session_cache.session_id + self.session_id = int.from_bytes(session_cache.session_id, "big") async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel_context.write( - msg, int.from_bytes(self.session_id, "big") - ) + return await self.channel_context.write(msg, self.session_id) @classmethod def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext": diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index 6f07a8855..03c7e6d7b 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -2,6 +2,7 @@ import ustruct # pyright:ignore[reportMissingModuleSource] from storage.cache_thp import BROADCAST_CHANNEL_ID from trezor import protobuf +from trezor.messages import CreateNewSession from .. import message_handler from ..protocol_common import Message @@ -82,5 +83,14 @@ def get_handshake_init_response() -> bytes: def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType: + print("decode message") expected_type = protobuf.type_for_wire(msg_type) - return message_handler.wrap_protobuf_load(buffer, expected_type) + x = message_handler.wrap_protobuf_load(buffer, expected_type) + print("result decoded", x) + return x + + +async def handle_CreateNewSession(msg: CreateNewSession) -> None: + print(msg.passphrase) + print(msg.on_device) + pass