1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

feat(core): session creation, part 2

This commit is contained in:
M1nd3r 2024-03-27 17:22:07 +01:00
parent 912c85e21e
commit a245ef195e
8 changed files with 110 additions and 18 deletions

View File

@ -0,0 +1,38 @@
syntax = "proto2";
package hw.trezor.messages.thp;
// Sugar for easier handling in Java
option java_package = "com.satoshilabs.trezor.lib.protobuf";
option java_outer_classname = "TrezorMessageThp";
// Numeric identifiers of pairing methods.
enum PairingMethod {
PairingMethod_None = 1; // Trust without MITM protection.
PairingMethod_CodeEntry = 2; // User types code diplayed on Trezor into the host application.
PairingMethod_QrCode = 3; // User scans code displayed on Trezor into host application.
PairingMethod_NFC_Unidirectional = 4; // Trezor transmits an authentication key to the host device via NFC.
}
message DeviceProperties {
optional string internal_model = 1; // Internal model name e.g. "T2B1".
optional uint32 model_variant = 2; // Encodes the device properties such as color.
optional bool bootloader_mode = 3; // Indicates whether the device is in bootloader or firmware mode.
optional uint32 protocol_version = 4; // The communication protocol version supported by the firmware.
repeated PairingMethod pairing_methods = 5; // The pairing methods supported by the Trezor.
}
message HandshakeCompletionReqNoisePayload {
optional bytes host_pairing_credential = 1; // Host's pairing credential
repeated PairingMethod pairing_methods = 2; // The pairing methods chosen by the host
}
message CreateNewSession{
optional string passphrase = 1;
optional bool on_device = 2; // user wants to enter passphrase on the device
}
message NewSession{
optional uint32 new_session_id = 1;
}

View File

@ -375,4 +375,15 @@ enum MessageType {
MessageType_SolanaAddress = 903 [(wire_out) = true];
MessageType_SolanaSignTx = 904 [(wire_in) = true];
MessageType_SolanaTxSignature = 905 [(wire_out) = true];
// THP
MessageType_StartPairingRequest = 1000 [(bitcoin_only) = true, (wire_in) = true];
MessageType_StartPairingResponse = 1001 [(bitcoin_only) = true, (wire_out) = true];
MessageType_CredentialRequest = 1002 [(bitcoin_only) = true, (wire_in) = true];
MessageType_CredentialResponse = 1003 [(bitcoin_only) = true, (wire_out) = true];
MessageType_EndRequest = 1004 [(bitcoin_only) = true, (wire_in) = true];
MessageType_EndResponse = 1005 [(bitcoin_only) = true, (wire_out) = true];
MessageType_CreateNewSession = 1006[(bitcoin_only)=true,(wire_in)=true];
MessageType_NewSession = 1007[(bitcoin_only)=true,(wire_out)=true];
}

View File

@ -47,7 +47,7 @@ class ChannelCache(ConnectionCache):
self.state = bytearray(_CHANNEL_STATE_LENGTH)
self.iface = bytearray(1) # TODO add decoding
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
self.session_id_counter = 0x01
self.session_id_counter = 0x00
self.fields = ()
super().__init__()
@ -276,13 +276,14 @@ def get_next_channel_id() -> bytes:
def get_next_session_id(channel: ChannelCache) -> bytes:
while not _is_session_id_unique(channel):
while True:
if channel.session_id_counter >= 255:
channel.session_id_counter = 1
else:
channel.session_id_counter += 1
if _is_session_id_unique(channel):
break
new_sid = channel.session_id_counter
channel.session_id_counter += 1
return new_sid.to_bytes(_SESSION_ID_LENGTH, "big")

View File

@ -264,6 +264,7 @@ if TYPE_CHECKING:
SolanaAddress = 903
SolanaSignTx = 904
SolanaTxSignature = 905
CreateNewSession = 1006
class FailureType(IntEnum):
UnexpectedMessage = 1

View File

@ -372,6 +372,22 @@ if TYPE_CHECKING:
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["PassphraseAck"]:
return isinstance(msg, cls)
class CreateNewSession(protobuf.MessageType):
passphrase: "str | None"
on_device: "bool | None"
def __init__(
self,
*,
passphrase: "str | None" = None,
on_device: "bool | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["CreateNewSession"]:
return isinstance(msg, cls)
class HDNodeType(protobuf.MessageType):
depth: "int"

View File

@ -7,7 +7,8 @@ import usb
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache
from trezor import loop, protobuf, utils
from trezor.wire.thp import thp_messages
from trezor.messages import CreateNewSession
from trezor.wire import message_handler
from ..protocol_common import Context
from . import ChannelState, SessionState, checksum
@ -22,7 +23,7 @@ from .thp_messages import (
from .thp_session import ThpError
if TYPE_CHECKING:
from trezorio import WireInterface # type:ignore
from trezorio import WireInterface # pyright:ignore[reportMissingImports]
_WIRE_INTERFACE_USB = b"\x01"
@ -182,6 +183,7 @@ class ChannelContext(Context):
# TODO ignore message
self._todo_clear_buffer()
return
if state is ChannelState.ENCRYPTED_TRANSPORT:
self._decrypt_buffer()
session_id, message_type = ustruct.unpack(
@ -189,13 +191,23 @@ class ChannelContext(Context):
)
if session_id == 0:
try:
message = thp_messages.decode_message(
self.buffer[INIT_DATA_OFFSET + 3 :], message_type
)
print(message)
except Exception as e:
print(e)
buf = self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH]
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
print(message)
# ------------------------------------------------TYPE ERROR------------------------------------------------
session_message: CreateNewSession = message
print("passphrase:", session_message.passphrase)
# await thp_messages.handle_CreateNewSession(message)
if session_message.passphrase is not None:
self.create_new_session(session_message.passphrase)
else:
self.create_new_session()
except Exception as e:
print("Proč??")
print(e)
return
# TODO not finished
if session_id not in self.sessions:
@ -255,8 +267,13 @@ class ChannelContext(Context):
self,
passphrase="",
) -> None: # TODO change it to output session data
pass
# create a new session with this passphrase
print("create new session")
from trezor.wire.thp.session_context import SessionContext
session = SessionContext.create_new_session(self)
print("help")
self.sessions[session.session_id] = session
print("new session created. Session id:", session.session_id)
# OTHER

View File

@ -18,12 +18,10 @@ class SessionContext(Context):
super().__init__(channel_context.iface, channel_context.channel_id)
self.channel_context = channel_context
self.session_cache = session_cache
self.session_id = session_cache.session_id
self.session_id = int.from_bytes(session_cache.session_id, "big")
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_context.write(
msg, int.from_bytes(self.session_id, "big")
)
return await self.channel_context.write(msg, self.session_id)
@classmethod
def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext":

View File

@ -2,6 +2,7 @@ import ustruct # pyright:ignore[reportMissingModuleSource]
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf
from trezor.messages import CreateNewSession
from .. import message_handler
from ..protocol_common import Message
@ -82,5 +83,14 @@ def get_handshake_init_response() -> bytes:
def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType:
print("decode message")
expected_type = protobuf.type_for_wire(msg_type)
return message_handler.wrap_protobuf_load(buffer, expected_type)
x = message_handler.wrap_protobuf_load(buffer, expected_type)
print("result decoded", x)
return x
async def handle_CreateNewSession(msg: CreateNewSession) -> None:
print(msg.passphrase)
print(msg.on_device)
pass