diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 83e1837f2..c94c5f6c5 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -289,12 +289,10 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["NewTransport"]]: for transport in new_enumerate_devices(): try: - print("test A") client = NewTrezorClient(transport) - session = client.get_session() + session = client.get_management_session() description = format_device_name(session.features) # client.end_session() - print("after end session") except DeviceIsBusy: description = "Device is in use by another process" except Exception: diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index 3998f1443..b0bcd344a 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -17,6 +17,7 @@ from __future__ import annotations import io +import logging from types import ModuleType from typing import Dict, Optional, Tuple, Type, TypeVar @@ -25,6 +26,7 @@ from typing_extensions import Self from . import messages, protobuf T = TypeVar("T") +LOG = logging.getLogger(__name__) class ProtobufMapping: @@ -63,7 +65,7 @@ class ProtobufMapping: wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) if wire_type is None: raise ValueError("Cannot encode class without wire type") - print("wire type", wire_type) + LOG.debug("encoding wire type %d", wire_type) buf = io.BytesIO() protobuf.dump_message(buf, msg) return wire_type, buf.getvalue() diff --git a/python/src/trezorlib/transport/new/client.py b/python/src/trezorlib/transport/new/client.py index 025008466..115bf5987 100644 --- a/python/src/trezorlib/transport/new/client.py +++ b/python/src/trezorlib/transport/new/client.py @@ -14,6 +14,8 @@ LOG = logging.getLogger(__name__) class NewTrezorClient: + management_session: Session | None = None + def __init__( self, transport: NewTransport, @@ -26,10 +28,8 @@ class NewTrezorClient: self.mapping = mapping.DEFAULT_MAPPING else: self.mapping = protobuf_mapping - print("test B") if protocol is None: - print("test C") self.protocol = self._get_protocol() else: self.protocol = protocol @@ -40,7 +40,10 @@ class NewTrezorClient: ) -> NewTrezorClient: ... def get_session( - self, passphrase: str = "", derive_cardano: bool = False + self, + passphrase: str = "", + derive_cardano: bool = False, + management_session: bool = False, ) -> Session: if isinstance(self.protocol, ProtocolV1): return SessionV1.new(self, passphrase, derive_cardano) @@ -48,6 +51,18 @@ class NewTrezorClient: return SessionV2.new(self, passphrase, derive_cardano) raise NotImplementedError # TODO + def get_management_session(self): + if self.management_session is not None: + return self.management_session + + if isinstance(self.protocol, ProtocolV1): + self.management_session = SessionV1.new(self, "", False) + elif isinstance(self.protocol, ProtocolV2): + self.management_session = SessionV2(self, b"\x00") + + assert self.management_session is not None + return self.management_session + def resume_session(self, session_id: bytes) -> Session: raise NotImplementedError # TODO @@ -66,7 +81,6 @@ class NewTrezorClient: response = protocol.read() self.transport.close() if isinstance(response, messages.Failure): - print("test F1") if ( response.code == FailureType.UnexpectedMessage and response.message == "Invalid protocol" diff --git a/python/src/trezorlib/transport/new/protocol_and_channel.py b/python/src/trezorlib/transport/new/protocol_and_channel.py index 083d87928..011be6c37 100644 --- a/python/src/trezorlib/transport/new/protocol_and_channel.py +++ b/python/src/trezorlib/transport/new/protocol_and_channel.py @@ -61,11 +61,9 @@ class ProtocolV1(ProtocolAndChannel): self._write(msg_type, msg_bytes) def _write(self, message_type: int, message_data: bytes) -> None: - print("wooooo") chunk_size = self.transport.CHUNK_SIZE header = struct.pack(">HL", message_type, len(message_data)) buffer = bytearray(b"##" + header + message_data) - print("wooooo") while buffer: # Report ID, data padded to 63 bytes diff --git a/python/src/trezorlib/transport/new/protocol_v2.py b/python/src/trezorlib/transport/new/protocol_v2.py index f5d5c0145..971053206 100644 --- a/python/src/trezorlib/transport/new/protocol_v2.py +++ b/python/src/trezorlib/transport/new/protocol_v2.py @@ -54,6 +54,7 @@ class ProtocolV2(ProtocolAndChannel): sync_bit_send: int sync_bit_receive: int has_valid_channel: bool = False + features: messages.Features def __init__( self, @@ -67,6 +68,7 @@ class ProtocolV2(ProtocolAndChannel): def get_channel(self) -> ProtocolV2: if not self.has_valid_channel: self._establish_new_channel() + self.update_features() return self def read(self, session_id: int) -> t.Any: @@ -77,7 +79,25 @@ class ProtocolV2(ProtocolAndChannel): msg_type, msg_data = self.mapping.encode(msg) self._encrypt_and_write(session_id, msg_type, msg_data, 7) # TODO add ctrl_byte - def _establish_new_channel(self): + def update_features(self) -> None: + message = messages.GetFeatures() + message_type, message_data = self.mapping.encode(message) + + self.session_id: int = 0 + self._encrypt_and_write( + MANAGEMENT_SESSION_ID, + message_type, + message_data, + 0x14, # TODO update control byte + ) + _ = self._read_until_valid_crc_check() # TODO check ACK + session_id, msg_type, msg_data = self.read_and_decrypt() + features = self.mapping.decode(msg_type, msg_data) + assert isinstance(features, messages.Features) + self.features = features + self._send_ack_2() + + def _establish_new_channel(self) -> None: self.sync_bit_send = 0 self.sync_bit_receive = 0 # Send channel allocation request @@ -124,7 +144,7 @@ class ProtocolV2(ProtocolAndChannel): noise_tag = payload[80:96] # TODO check noise tag - print("noise_tag: ", hexlify(noise_tag).decode()) + LOG.debug("noise_tag: %s", hexlify(noise_tag).decode()) # Prepare and send handshake completion request PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" @@ -143,7 +163,6 @@ class ProtocolV2(ProtocolAndChannel): trezor_masked_static_pubkey = aes_ctx.decrypt( IV_1, encrypted_trezor_static_pubkey, h ) - # print("masked_key", hexlify(trezor_masked_static_pubkey).decode()) except Exception as e: print(type(e)) # TODO how to handle potential exceptions? Q for Matejcik h = _sha256_of_two(h, encrypted_trezor_static_pubkey) @@ -259,7 +278,7 @@ class ProtocolV2(ProtocolAndChannel): self.transport, header, encrypted_message ) - def read_and_decrypt(self) -> t.Tuple[bytes, int, bytes]: + def read_and_decrypt(self) -> t.Tuple[int, int, bytes]: header, raw_payload = self._read_until_valid_crc_check() if not header.is_encrypted_transport(): print("Trying to decrypt not encrypted message!") @@ -272,7 +291,7 @@ class ProtocolV2(ProtocolAndChannel): message_type = message[1:3] message_data = message[3:] return ( - int.to_bytes(session_id, 1, "big"), + session_id, int.from_bytes(message_type, "big"), message_data, ) diff --git a/python/src/trezorlib/transport/new/session.py b/python/src/trezorlib/transport/new/session.py index 6c04f9eed..8cd2b57e0 100644 --- a/python/src/trezorlib/transport/new/session.py +++ b/python/src/trezorlib/transport/new/session.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from ...messages import Features, Initialize +from ...messages import Features, Initialize, ThpCreateNewSession, ThpNewSession from .protocol_and_channel import ProtocolV1 from .protocol_v2 import ProtocolV2 @@ -34,11 +34,11 @@ class SessionV1(Session): ) -> SessionV1: assert isinstance(client.protocol, ProtocolV1) session = SessionV1(client, b"") - cls.features = session.call( + session.features = session.call( # Initialize(passphrase=passphrase, derive_cardano=derive_cardano) # TODO Initialize() ) - session.id = cls.features.session_id + session.id = session.features.session_id return session def call(self, msg: t.Any, should_reinit: bool = False) -> t.Any: @@ -51,15 +51,32 @@ class SessionV1(Session): class SessionV2(Session): + + @classmethod + def new( + cls, client: NewTrezorClient, passphrase: str, derive_cardano: bool + ) -> SessionV2: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV2(client, b"\x00") + new_session: ThpNewSession = session.call( + ThpCreateNewSession(passphrase=passphrase, derive_cardano=derive_cardano) + ) + assert new_session.new_session_id is not None + session_id = new_session.new_session_id + session.update_id_and_sid(session_id.to_bytes(1, "big")) + return session + def __init__(self, client: NewTrezorClient, id: bytes) -> None: super().__init__(client, id) assert isinstance(client.protocol, ProtocolV2) - self.channel = client.protocol.get_channel() - self.sid = self._convert_id_to_sid(id) + self.channel: ProtocolV2 = client.protocol.get_channel() + self.update_id_and_sid(id) + self.features = self.channel.features def call(self, msg: t.Any) -> t.Any: self.channel.write(self.sid, msg) return self.channel.read(self.sid) - def _convert_id_to_sid(self, id: bytes) -> int: - return int.from_bytes(id, "big") # TODO update to extract only sid + def update_id_and_sid(self, id: bytes) -> None: + self.id = id + self.sid = int.from_bytes(id, "big") # TODO update to extract only sid