diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index da2a126eba..0b14778ed7 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -26,7 +26,7 @@ from contextlib import contextmanager import click from .. import exceptions, transport, ui -from ..client import PROTOCOL_V2, TrezorClient +from ..client import ProtocolVersion, TrezorClient from ..messages import Capability from ..transport import Transport from ..transport.session import Session, SessionV1, SessionV2 @@ -150,11 +150,11 @@ class TrezorConnection: # Try resume session from id if self.session_id is not None: - if client.protocol_version is Session.CODEC_V1: + if client.protocol_version is ProtocolVersion.PROTOCOL_V1: session = SessionV1.resume_from_id( client=client, session_id=self.session_id ) - elif client.protocol_version is Session.THP_V2: + elif client.protocol_version is ProtocolVersion.PROTOCOL_V2: session = SessionV2(client, self.session_id) # TODO fix resumption on THP else: @@ -311,7 +311,7 @@ def with_client( try: return func(client, *args, **kwargs) finally: - if client.protocol_version == PROTOCOL_V2: + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: get_channel_db().save_channel(client.protocol) # if not session_was_resumed: # try: diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index da784a5a6b..b3a885e4c8 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca import click from .. import __version__, log, messages, protobuf -from ..client import TrezorClient +from ..client import ProtocolVersion, TrezorClient from ..transport import DeviceIsBusy, enumerate_devices from ..transport.session import Session from ..transport.thp import channel_database @@ -308,7 +308,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: try: client = get_client(transport) description = format_device_name(client.features) - if client.protocol_version == Session.THP_V2: + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: get_channel_db().save_channel(client.protocol) except DeviceIsBusy: description = "Device is in use by another process" diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 6bb8a2a27d..d82554dd93 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -18,6 +18,7 @@ from __future__ import annotations import logging import os import typing as t +from enum import IntEnum from . import mapping, messages, models from .mapping import ProtobufMapping @@ -48,9 +49,11 @@ Or visit https://suite.trezor.io/ LOG = logging.getLogger(__name__) -UNKNOWN = -1 -PROTOCOL_V1 = 1 -PROTOCOL_V2 = 2 + +class ProtocolVersion(IntEnum): + UNKNOWN = 0x00 + PROTOCOL_V1 = 0x01 # Codec + PROTOCOL_V2 = 0x02 # THP class TrezorClient: @@ -80,12 +83,13 @@ class TrezorClient: else: self.protocol = protocol self.protocol.mapping = self.mapping + if isinstance(self.protocol, ProtocolV1): - self._protocol_version = PROTOCOL_V1 + self._protocol_version = ProtocolVersion.PROTOCOL_V1 elif isinstance(self.protocol, ProtocolV2): - self._protocol_version = PROTOCOL_V2 + self._protocol_version = ProtocolVersion.PROTOCOL_V2 else: - self._protocol_version = UNKNOWN + self._protocol_version = ProtocolVersion.UNKNOWN @classmethod def resume( diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index ba24b8109a..707401cf1b 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -42,7 +42,7 @@ from .exceptions import Cancelled, PinException, TrezorFailure from .log import DUMP_BYTES from .messages import Capability, DebugWaitType from .tools import expect, parse_path -from .transport.session import Session, SessionV1, SessionV2 +from .transport.session import Session, SessionV1 from .transport.thp.protocol_v1 import ProtocolV1 if t.TYPE_CHECKING: @@ -1031,18 +1031,12 @@ class SessionDebugWrapper(Session): def __init__(self, session: Session) -> None: self._session = session self.reset_debug_features() - if isinstance(session, SessionV1): - self.client.session_version = 1 - elif isinstance(session, SessionV2): - self.client.session_version = 2 - elif isinstance(session, SessionDebugWrapper): + if isinstance(session, SessionDebugWrapper): raise Exception("Cannot wrap already wrapped session!") - else: - self.client.session_version = -1 # UNKNOWN @property - def session_version(self) -> int: - return self.client.session_version + def protocol_version(self) -> int: + return self.client.protocol_version @property def client(self) -> TrezorClientDebugLink: @@ -1284,7 +1278,6 @@ class TrezorClientDebugLink(TrezorClient): # by the device. def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: - self._session_version: int = -1 try: debug_transport = transport.find_debug() self.debug = DebugLink(debug_transport, auto_interact) @@ -1311,14 +1304,6 @@ class TrezorClientDebugLink(TrezorClient): self.debug.version = self.version self.passphrase: str | None = None - @property - def session_version(self) -> int: - return self._session_version - - @session_version.setter - def session_version(self, value: int) -> None: - self._session_version = value - @property def layout_type(self) -> LayoutType: return self.debug.layout_type diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index 1d0948ddf5..6b6f4cce2c 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -14,8 +14,6 @@ LOG = logging.getLogger(__name__) class Session: - CODEC_V1: t.Final[int] = 1 - THP_V2: t.Final[int] = 2 button_callback: t.Callable[[Session, t.Any], t.Any] | None = None pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None diff --git a/tests/device_tests/test_basic.py b/tests/device_tests/test_basic.py index 50dee4a42e..2955615e11 100644 --- a/tests/device_tests/test_basic.py +++ b/tests/device_tests/test_basic.py @@ -15,6 +15,7 @@ # If not, see . from trezorlib import device, messages, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client @@ -22,7 +23,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client def test_features(client: Client): session = client.get_session() f0 = session.features - if Session(session).session_version == Session.CODEC_V1: + if client.protocol_version == ProtocolVersion.PROTOCOL_V1: # session erases session_id from its features f0.session_id = session.id f1 = session.call(messages.Initialize(session_id=session.id)) diff --git a/tests/device_tests/test_debuglink.py b/tests/device_tests/test_debuglink.py index 7516da8a5d..d9445fddec 100644 --- a/tests/device_tests/test_debuglink.py +++ b/tests/device_tests/test_debuglink.py @@ -17,6 +17,7 @@ import pytest from trezorlib import debuglink, device, messages, misc +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.tools import parse_path @@ -62,7 +63,7 @@ def test_pin(session: Session): @pytest.mark.models("core") def test_softlock_instability(session: Session): - if session.session_version == Session.THP_V2: + if session.protocol_version == ProtocolVersion.PROTOCOL_V2: raise Exception("THIS NEEDS TO BE CHANGED FOR THP") def load_device(): diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 6179acb82f..18fde33506 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -19,6 +19,7 @@ from pathlib import Path import pytest from trezorlib import btc, device, exceptions, messages, misc, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path @@ -205,7 +206,7 @@ def test_apply_homescreen_toif(session: Session): @pytest.mark.models(skip=["legacy", "safe3"]) def test_apply_homescreen_jpeg(session: Session): - if session.session_version is Session.THP_V2: + if session.protocol_version is ProtocolVersion.PROTOCOL_V2: raise Exception( "FAILS BECAUSE THE MESSAGE IS BIGGER THAN THE INTERNAL READ BUFFER" ) diff --git a/tests/device_tests/test_protection_levels.py b/tests/device_tests/test_protection_levels.py index 1f49657e35..7825cbacee 100644 --- a/tests/device_tests/test_protection_levels.py +++ b/tests/device_tests/test_protection_levels.py @@ -17,6 +17,7 @@ import pytest from trezorlib import btc, device, messages, misc, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.exceptions import TrezorFailure @@ -61,17 +62,17 @@ def _assert_protection( client.refresh_features() assert client.features.pin_protection is pin assert client.features.passphrase_protection is passphrase - if session.session_version == Session.THP_V2: + if session.protocol_version == ProtocolVersion.PROTOCOL_V2: new_session = session.client.get_session() session.lock() session.end() - if session.session_version == Session.CODEC_V1: + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: new_session = session.client.get_session() return Session(new_session) def test_initialize(session: Session): - if session.session_version == Session.THP_V2: + if session.protocol_version == ProtocolVersion.PROTOCOL_V2: # Test is skipped for THP return @@ -194,7 +195,7 @@ def test_get_public_key(session: Session): client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] - if session.session_version == Session.CODEC_V1: + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: expected_responses.append(messages.PassphraseRequest) expected_responses.append(messages.PublicKey) @@ -208,7 +209,7 @@ def test_get_address(session: Session): with session, session.client as client: client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] - if session.session_version == Session.CODEC_V1: + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: expected_responses.append(messages.PassphraseRequest) expected_responses.append(messages.Address) @@ -323,7 +324,7 @@ def test_sign_message(session: Session): expected_responses = [_pin_request(session)] - if session.session_version == Session.CODEC_V1: + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: expected_responses.append(messages.PassphraseRequest) expected_responses.extend( @@ -409,7 +410,7 @@ def test_signtx(session: Session): with session, session.client as client: client.use_pin_sequence([PIN4]) expected_responses = [_pin_request(session)] - if session.session_version == Session.CODEC_V1: + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: expected_responses.append(messages.PassphraseRequest) expected_responses.extend( [ @@ -463,11 +464,11 @@ def test_unlocked(session: Session): def test_passphrase_cached(session: Session): session = _assert_protection(session, pin=False) with session: - if session.session_version == 1: + if session.protocol_version == 1: session.set_expected_responses( [messages.PassphraseRequest, messages.Address] ) - elif session.session_version == 2: + elif session.protocol_version == 2: session.set_expected_responses([messages.Address]) else: raise Exception("Unknown session type") diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index 45cec6acf2..56b8ace996 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -18,6 +18,7 @@ import pytest from trezorlib import cardano, messages, models from trezorlib.btc import get_public_node +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.exceptions import TrezorFailure @@ -33,7 +34,7 @@ PIN4 = "1234" def test_thp_end_session(client: Client): session = Session(client.get_session()) - if session.session_version == Session.CODEC_V1: + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: # TODO: This test should be skipped on non-THP builds return diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 9dc059e978..6aa7dced5b 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -19,6 +19,7 @@ import random import pytest from trezorlib import device, exceptions, messages +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client @@ -65,7 +66,10 @@ def _get_xpub( ] else: expected_responses = [messages.PublicKey] - if passphrase_v1 is not None and session.session_version == Session.CODEC_V1: + if ( + passphrase_v1 is not None + and session.protocol_version == ProtocolVersion.PROTOCOL_V1 + ): session.passphrase = passphrase_v1 with session: diff --git a/tests/persistence_tests/test_shamir_persistence.py b/tests/persistence_tests/test_shamir_persistence.py index a545ee556c..1524bd5203 100644 --- a/tests/persistence_tests/test_shamir_persistence.py +++ b/tests/persistence_tests/test_shamir_persistence.py @@ -17,8 +17,8 @@ import pytest from trezorlib import device, messages +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import DebugLink, LayoutType -from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import RecoveryStatus from ..click_tests import common, recovery @@ -158,7 +158,7 @@ def test_recovery_on_old_wallet(core_emulator: Emulator): layout = debug.read_layout() # while keyboard is open, hit the device with Initialize/GetFeatures - if device_handler.client.session_version == Session.CODEC_V1: + if device_handler.client.protocol_version == ProtocolVersion.PROTOCOL_V1: device_handler.client.get_management_session().call(messages.Initialize()) device_handler.client.refresh_features() diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 2531fd450d..94682a9b19 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -21,6 +21,7 @@ import pytest from shamir_mnemonic import shamir from trezorlib import btc, debuglink, device, exceptions, fido, messages, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.messages import ( ApplySettings, @@ -373,7 +374,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): # Get a passphrase-less and a passphrased address. address = btc.get_address(session, "Bitcoin", PATH) - if session.session_version == Session.CODEC_V1: + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: session.call(messages.Initialize(new_session=True)) new_session = emu.client.get_session(passphrase="TREZOR") address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index be8c7dc74b..bdeb74cabf 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -20,6 +20,7 @@ import pytest from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib._internal.emulator import Emulator +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.tools import parse_path @@ -139,7 +140,7 @@ def test_init_device(emulator: Emulator): btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) # in TT < 2.3.0 session_id will only be available after PassphraseStateRequest session_id = session.id - if session.session_version == Session.CODEC_V1: + if session.protocol_version == ProtocolVersion.PROTOCOL_V1: session.call(messages.Initialize(session_id=session_id)) btc.get_address( session,