diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 08463063ca..33f3a15b1f 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -26,7 +26,7 @@ from .tools import parse_path from .transport import Transport, get_transport from .transport.thp.protocol_and_channel import Channel from .transport.thp.protocol_v1 import ProtocolV1Channel -from .transport.thp.protocol_v2 import ProtocolV2Channel +from .transport.thp.protocol_v2 import ProtocolV2Channel, TrezorState if t.TYPE_CHECKING: from .transport.session import Session, SessionV1 @@ -62,6 +62,7 @@ class TrezorClient: _setup_pin: str | None = None # Should be used only by conftest _last_active_session: SessionV1 | None = None + _session_id_counter: int = 0 def __init__( self, transport: Transport, @@ -99,6 +100,26 @@ class TrezorClient: else: raise Exception("Unknown protocol version") + def do_pairing(self) -> None: + from .transport.session import SessionV2 + + assert self.protocol_version == ProtocolVersion.V2 + session = SessionV2(client=self, id=b"\x00") + session.call( + messages.ThpPairingRequest(host_name="Trezorlib"), + expect=messages.ThpPairingRequestApproved, + skip_firmware_version_check=True, + ) + session.call( + messages.ThpSelectMethod( + selected_pairing_method=messages.ThpPairingMethod.SkipPairing + ), + expect=messages.ThpEndResponse, + skip_firmware_version_check=True, + ) + assert isinstance(self.protocol, ProtocolV2Channel) + self.protocol._has_valid_channel = True + def get_session( self, passphrase: str | object = "", @@ -128,17 +149,21 @@ class TrezorClient: if isinstance(self.protocol, ProtocolV2Channel): from .transport.session import SessionV2 + if self.protocol.trezor_state is TrezorState.UNPAIRED: + self.do_pairing() + if passphrase is SEEDLESS: return SessionV2(self, id=b"\x00") + if self._session_id_counter >= 255: + self._session_id_counter = 0 assert isinstance(passphrase, str) or passphrase is None - session_id = b"\x01" # TODO fix this with ProtocolV2 session rework - if session_id is not None: - sid = int.from_bytes(session_id, "big") - else: - sid = 1 - assert 0 <= sid <= 255 - return SessionV2.new(self, passphrase, derive_cardano, sid) + self._session_id_counter += 1 + + return SessionV2.new( + self, passphrase, derive_cardano, self._session_id_counter + ) + raise NotImplementedError def get_seedless_session(self) -> Session: @@ -150,11 +175,20 @@ class TrezorClient: @property def features(self) -> messages.Features: if self._features is None: - self._features = self.protocol.get_features() + self._features = self._get_features() self.check_firmware_version(warn_only=True) assert self._features is not None return self._features + def _get_features(self) -> messages.Features: + if isinstance(self.protocol, ProtocolV2Channel): + if ( + self.protocol.trezor_state is TrezorState.UNPAIRED + or not self.protocol._has_valid_channel + ): + self.do_pairing() + return self.protocol.get_features() + @property def protocol_version(self) -> int: return self._protocol_version diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 698147b805..56409f722e 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1030,6 +1030,8 @@ class TrezorClientDebugLink(TrezorClient): # without special DebugLink interface provided # by the device. + protocol: ProtocolV1Channel | ProtocolV2Channel + def __init__( self, transport: Transport, @@ -1075,8 +1077,8 @@ class TrezorClientDebugLink(TrezorClient): # and know the supported debug capabilities if self.protocol_version is ProtocolVersion.V2: assert isinstance(self.protocol, ProtocolV2Channel) - self.protocol._helper_debug = self.debug - self.protocol = self.protocol.get_channel() + self.do_pairing() + # self.protocol = self.protocol.get_channel() self.debug.model = self.model self.debug.version = self.version diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index 60212e6d80..7a0af18e20 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -26,9 +26,11 @@ class Session: self, msg: MessageType, expect: type[MT] = MessageType, + skip_firmware_version_check: bool = False, _passphrase_ack: messages.PassphraseAck | None = None, ) -> MT: - self.client.check_firmware_version() + if not skip_firmware_version_check: + self.client.check_firmware_version() resp = self.call_raw(msg) while True: @@ -260,14 +262,11 @@ class SessionV2(Session): return session def __init__(self, client: TrezorClient, id: bytes) -> None: - from ..debuglink import TrezorClientDebugLink super().__init__(client, id) assert isinstance(client.protocol, ProtocolV2Channel) - if isinstance(client, TrezorClientDebugLink): - client.protocol._helper_debug = client.debug - self.channel: ProtocolV2Channel = client.protocol.get_channel() + self.channel: ProtocolV2Channel = client.protocol self.update_id_and_sid(id) def _write(self, msg: t.Any) -> None: diff --git a/python/src/trezorlib/transport/thp/protocol_and_channel.py b/python/src/trezorlib/transport/thp/protocol_and_channel.py index 731caa963e..b2e7dbc28f 100644 --- a/python/src/trezorlib/transport/thp/protocol_and_channel.py +++ b/python/src/trezorlib/transport/thp/protocol_and_channel.py @@ -1,14 +1,7 @@ -from __future__ import annotations - -import logging -import typing as t - from ... import messages from ...mapping import ProtobufMapping from .. import Transport -LOG = logging.getLogger(__name__) - class Channel: @@ -25,9 +18,3 @@ class Channel: def update_features(self) -> None: raise NotImplementedError - - def read(self, timeout: float | None = None) -> t.Any: - raise NotImplementedError - - def write(self, msg: t.Any) -> None: - raise NotImplementedError diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index 1651c05b71..a703638933 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -4,6 +4,7 @@ import logging import os import typing as t from binascii import hexlify +from enum import IntEnum from noise.connection import Keypair, NoiseConnection @@ -21,10 +22,15 @@ LOG = logging.getLogger(__name__) DEFAULT_SESSION_ID: int = 0 if t.TYPE_CHECKING: - from ...debuglink import DebugLink + pass MT = t.TypeVar("MT", bound=protobuf.MessageType) +class TrezorState(IntEnum): + UNPAIRED = 0x00 + PAIRED = 0x01 + + class ProtocolV2Channel(Channel): channel_id: int sync_bit_send: int @@ -33,18 +39,20 @@ class ProtocolV2Channel(Channel): _has_valid_channel: bool = False _features: messages.Features | None = None - _helper_debug: DebugLink | None = None + trezor_state: int = TrezorState.UNPAIRED def __init__( self, transport: Transport, mapping: ProtobufMapping, + credential: bytes | None = None, ) -> None: super().__init__(transport, mapping) + self.trezor_state = self.prepare_channel_without_pairing(credential=credential) def get_channel(self) -> ProtocolV2Channel: if not self._has_valid_channel: - self._establish_new_channel(self._helper_debug) + raise RuntimeError("Channel is invalidated") return self def read(self, session_id: int) -> t.Any: @@ -61,7 +69,7 @@ class ProtocolV2Channel(Channel): def get_features(self) -> messages.Features: if not self._has_valid_channel: - self._establish_new_channel(self._helper_debug) + raise RuntimeError("Channel is invalidated") if self._features is None: self.update_features() assert self._features is not None @@ -96,11 +104,10 @@ class ProtocolV2Channel(Channel): assert isinstance(msg, message_type) return msg - def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None: + def prepare_channel_without_pairing(self, credential: bytes | None = None) -> int: self._reset_sync_bits() self._do_channel_allocation() - self._do_handshake() - self._do_pairing(helper_debug) + return self._do_handshake(credential=credential) def _reset_sync_bits(self) -> None: self.sync_bit_send = 0 @@ -148,7 +155,7 @@ class ProtocolV2Channel(Channel): self, credential: bytes | None = None, host_static_randomness: bytes | None = None, - ): + ) -> int: randomness_static = host_static_randomness or os.urandom(32) @@ -160,7 +167,7 @@ class ProtocolV2Channel(Channel): credential, ) self._read_ack() - self._read_handshake_completion_response() + return self._read_handshake_completion_response() def _send_handshake_init_request(self) -> None: ha_init_req_header = MessageHeader(0, self.channel_id, 36) @@ -215,7 +222,7 @@ class ProtocolV2Channel(Channel): ) self.handshake_hash = self._noise.get_handshake_hash() - def _read_handshake_completion_response(self) -> None: + def _read_handshake_completion_response(self) -> int: # Read handshake completion response, ignore payload as we do not care about the state header, data = self._read_until_valid_crc_check() if not header.is_handshake_comp_response(): @@ -228,25 +235,7 @@ class ProtocolV2Channel(Channel): print("trezor state:", trezor_state) assert trezor_state == b"\x00" or trezor_state == b"\x01" self._send_ack_1() - - def _do_pairing(self, helper_debug: DebugLink | None): - - self._send_message(messages.ThpPairingRequest(host_name="Trezorlib")) - self._read_message(messages.ButtonRequest) - self._send_message(messages.ButtonAck()) - - if helper_debug is not None: - helper_debug.press_yes() - - self._read_message(messages.ThpPairingRequestApproved) - self._send_message( - messages.ThpSelectMethod( - selected_pairing_method=messages.ThpPairingMethod.SkipPairing - ) - ) - self._read_message(messages.ThpEndResponse) - - self._has_valid_channel = True + return int.from_bytes(trezor_state, "big") def _read_ack(self): header, payload = self._read_until_valid_crc_check() diff --git a/tests/device_tests/thp/connect.py b/tests/device_tests/thp/connect.py index 2b6513d449..26678de128 100644 --- a/tests/device_tests/thp/connect.py +++ b/tests/device_tests/thp/connect.py @@ -27,11 +27,11 @@ def prepare_protocol_for_pairing( def get_encrypted_transport_protocol( client: Client, host_static_randomness: bytes | None = None ) -> ProtocolV2Channel: - protocol = prepare_protocol_for_pairing( + client.protocol = prepare_protocol_for_pairing( client, host_static_randomness=host_static_randomness ) - protocol._do_pairing(client.debug) - return protocol + client.do_pairing() + return client.protocol def handle_pairing_request( diff --git a/tests/device_tests/thp/test_handshake.py b/tests/device_tests/thp/test_handshake.py index 7a31ff3644..8bc814fccf 100644 --- a/tests/device_tests/thp/test_handshake.py +++ b/tests/device_tests/thp/test_handshake.py @@ -47,4 +47,5 @@ def test_handshake(client: Client) -> None: # TODO - without pairing, the client is damaged and results in fail of the following test # so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error - protocol._do_pairing(client.debug) + client.protocol = protocol + client.do_pairing() diff --git a/tests/device_tests/thp/test_multiple_hosts.py b/tests/device_tests/thp/test_multiple_hosts.py index 6f9d708b75..b74531c62b 100644 --- a/tests/device_tests/thp/test_multiple_hosts.py +++ b/tests/device_tests/thp/test_multiple_hosts.py @@ -50,10 +50,12 @@ def _prepare_two_hosts(client: Client) -> tuple[ProtocolV2Channel, ProtocolV2Cha ) protocol_1._do_handshake() - protocol_1._do_pairing(client.debug) + client.protocol = protocol_1 + client.do_pairing() sleep(LOCK_TIME) protocol_2._do_handshake() - protocol_2._do_pairing(client.debug) + client.protocol = protocol_2 + client.do_pairing() return protocol_1, protocol_2 @@ -122,7 +124,8 @@ def test_concurrent_handshakes_1(client: Client) -> None: # The second host performs action that results # in the invalidation of the first host's handshake state - protocol_2._do_pairing(helper_debug=client.debug) + client.protocol = protocol_2 + client.do_pairing() # Even after LOCK_TIME passes, the first host's channel cannot # be resumed diff --git a/tests/device_tests/thp/test_pairing.py b/tests/device_tests/thp/test_pairing.py index 13ecf18455..845253fb71 100644 --- a/tests/device_tests/thp/test_pairing.py +++ b/tests/device_tests/thp/test_pairing.py @@ -304,16 +304,16 @@ def test_channel_replacement(client: Client) -> None: client.protocol = get_encrypted_transport_protocol(client, host_static_randomness) - session = client.get_session(passphrase="TREZOR", session_id=b"\x10") + session = client.get_session(passphrase="TREZOR") address = get_test_address(session) - session_2 = client.get_session(passphrase="ROZERT", session_id=b"\x20") + session_2 = client.get_session(passphrase="ROZERT") address_2 = get_test_address(session_2) assert address != address_2 # create new channel using the same host_static_privkey client.protocol = get_encrypted_transport_protocol(client, host_static_randomness) - session_3 = client.get_session(passphrase="OKIDOKI", session_id=b"\x30") + session_3 = client.get_session(passphrase="OKIDOKI") address_3 = get_test_address(session_3) assert address_3 != address_2 @@ -333,6 +333,6 @@ def test_channel_replacement(client: Client) -> None: _ = get_test_address(session_3) assert str(e_2.value.message) == "Invalid session" - session_4 = client.get_session(passphrase="TREZOR", session_id=b"\x40") + session_4 = client.get_session(passphrase="TREZOR") super_new_address = get_test_address(session_4) assert address == super_new_address