1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-11 08:58:08 +00:00

chore(python): unify session and protocol versions under one IntEnum

[no changelog]
This commit is contained in:
M1nd3r 2024-11-29 11:11:00 +01:00
parent dbb0d44ca5
commit a29f21b2be
14 changed files with 49 additions and 51 deletions

View File

@ -26,7 +26,7 @@ from contextlib import contextmanager
import click import click
from .. import exceptions, transport, ui from .. import exceptions, transport, ui
from ..client import PROTOCOL_V2, TrezorClient from ..client import ProtocolVersion, TrezorClient
from ..messages import Capability from ..messages import Capability
from ..transport import Transport from ..transport import Transport
from ..transport.session import Session, SessionV1, SessionV2 from ..transport.session import Session, SessionV1, SessionV2
@ -150,11 +150,11 @@ class TrezorConnection:
# Try resume session from id # Try resume session from id
if self.session_id is not None: if self.session_id is not None:
if client.protocol_version is Session.CODEC_V1: if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
session = SessionV1.resume_from_id( session = SessionV1.resume_from_id(
client=client, session_id=self.session_id client=client, session_id=self.session_id
) )
elif client.protocol_version is Session.THP_V2: elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
session = SessionV2(client, self.session_id) session = SessionV2(client, self.session_id)
# TODO fix resumption on THP # TODO fix resumption on THP
else: else:
@ -311,7 +311,7 @@ def with_client(
try: try:
return func(client, *args, **kwargs) return func(client, *args, **kwargs)
finally: finally:
if client.protocol_version == PROTOCOL_V2: if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol) get_channel_db().save_channel(client.protocol)
# if not session_was_resumed: # if not session_was_resumed:
# try: # try:

View File

@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
import click import click
from .. import __version__, log, messages, protobuf from .. import __version__, log, messages, protobuf
from ..client import TrezorClient from ..client import ProtocolVersion, TrezorClient
from ..transport import DeviceIsBusy, enumerate_devices from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session from ..transport.session import Session
from ..transport.thp import channel_database from ..transport.thp import channel_database
@ -308,7 +308,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
try: try:
client = get_client(transport) client = get_client(transport)
description = format_device_name(client.features) description = format_device_name(client.features)
if client.protocol_version == Session.THP_V2: if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol) get_channel_db().save_channel(client.protocol)
except DeviceIsBusy: except DeviceIsBusy:
description = "Device is in use by another process" description = "Device is in use by another process"

View File

@ -18,6 +18,7 @@ from __future__ import annotations
import logging import logging
import os import os
import typing as t import typing as t
from enum import IntEnum
from . import mapping, messages, models from . import mapping, messages, models
from .mapping import ProtobufMapping from .mapping import ProtobufMapping
@ -48,9 +49,11 @@ Or visit https://suite.trezor.io/
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
UNKNOWN = -1
PROTOCOL_V1 = 1 class ProtocolVersion(IntEnum):
PROTOCOL_V2 = 2 UNKNOWN = 0x00
PROTOCOL_V1 = 0x01 # Codec
PROTOCOL_V2 = 0x02 # THP
class TrezorClient: class TrezorClient:
@ -80,12 +83,13 @@ class TrezorClient:
else: else:
self.protocol = protocol self.protocol = protocol
self.protocol.mapping = self.mapping self.protocol.mapping = self.mapping
if isinstance(self.protocol, ProtocolV1): if isinstance(self.protocol, ProtocolV1):
self._protocol_version = PROTOCOL_V1 self._protocol_version = ProtocolVersion.PROTOCOL_V1
elif isinstance(self.protocol, ProtocolV2): elif isinstance(self.protocol, ProtocolV2):
self._protocol_version = PROTOCOL_V2 self._protocol_version = ProtocolVersion.PROTOCOL_V2
else: else:
self._protocol_version = UNKNOWN self._protocol_version = ProtocolVersion.UNKNOWN
@classmethod @classmethod
def resume( def resume(

View File

@ -42,7 +42,7 @@ from .exceptions import Cancelled, PinException, TrezorFailure
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import Capability, DebugWaitType from .messages import Capability, DebugWaitType
from .tools import expect, parse_path from .tools import expect, parse_path
from .transport.session import Session, SessionV1, SessionV2 from .transport.session import Session, SessionV1
from .transport.thp.protocol_v1 import ProtocolV1 from .transport.thp.protocol_v1 import ProtocolV1
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
@ -1031,18 +1031,12 @@ class SessionDebugWrapper(Session):
def __init__(self, session: Session) -> None: def __init__(self, session: Session) -> None:
self._session = session self._session = session
self.reset_debug_features() self.reset_debug_features()
if isinstance(session, SessionV1): if isinstance(session, SessionDebugWrapper):
self.client.session_version = 1
elif isinstance(session, SessionV2):
self.client.session_version = 2
elif isinstance(session, SessionDebugWrapper):
raise Exception("Cannot wrap already wrapped session!") raise Exception("Cannot wrap already wrapped session!")
else:
self.client.session_version = -1 # UNKNOWN
@property @property
def session_version(self) -> int: def protocol_version(self) -> int:
return self.client.session_version return self.client.protocol_version
@property @property
def client(self) -> TrezorClientDebugLink: def client(self) -> TrezorClientDebugLink:
@ -1284,7 +1278,6 @@ class TrezorClientDebugLink(TrezorClient):
# by the device. # by the device.
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
self._session_version: int = -1
try: try:
debug_transport = transport.find_debug() debug_transport = transport.find_debug()
self.debug = DebugLink(debug_transport, auto_interact) self.debug = DebugLink(debug_transport, auto_interact)
@ -1311,14 +1304,6 @@ class TrezorClientDebugLink(TrezorClient):
self.debug.version = self.version self.debug.version = self.version
self.passphrase: str | None = None self.passphrase: str | None = None
@property
def session_version(self) -> int:
return self._session_version
@session_version.setter
def session_version(self, value: int) -> None:
self._session_version = value
@property @property
def layout_type(self) -> LayoutType: def layout_type(self) -> LayoutType:
return self.debug.layout_type return self.debug.layout_type

View File

@ -14,8 +14,6 @@ LOG = logging.getLogger(__name__)
class Session: class Session:
CODEC_V1: t.Final[int] = 1
THP_V2: t.Final[int] = 2
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None

View File

@ -15,6 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from trezorlib import device, messages, models from trezorlib import device, messages, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
@ -22,7 +23,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
def test_features(client: Client): def test_features(client: Client):
session = client.get_session() session = client.get_session()
f0 = session.features f0 = session.features
if Session(session).session_version == Session.CODEC_V1: if client.protocol_version == ProtocolVersion.PROTOCOL_V1:
# session erases session_id from its features # session erases session_id from its features
f0.session_id = session.id f0.session_id = session.id
f1 = session.call(messages.Initialize(session_id=session.id)) f1 = session.call(messages.Initialize(session_id=session.id))

View File

@ -17,6 +17,7 @@
import pytest import pytest
from trezorlib import debuglink, device, messages, misc from trezorlib import debuglink, device, messages, misc
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
@ -62,7 +63,7 @@ def test_pin(session: Session):
@pytest.mark.models("core") @pytest.mark.models("core")
def test_softlock_instability(session: Session): def test_softlock_instability(session: Session):
if session.session_version == Session.THP_V2: if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
raise Exception("THIS NEEDS TO BE CHANGED FOR THP") raise Exception("THIS NEEDS TO BE CHANGED FOR THP")
def load_device(): def load_device():

View File

@ -19,6 +19,7 @@ from pathlib import Path
import pytest import pytest
from trezorlib import btc, device, exceptions, messages, misc, models from trezorlib import btc, device, exceptions, messages, misc, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
@ -205,7 +206,7 @@ def test_apply_homescreen_toif(session: Session):
@pytest.mark.models(skip=["legacy", "safe3"]) @pytest.mark.models(skip=["legacy", "safe3"])
def test_apply_homescreen_jpeg(session: Session): def test_apply_homescreen_jpeg(session: Session):
if session.session_version is Session.THP_V2: if session.protocol_version is ProtocolVersion.PROTOCOL_V2:
raise Exception( raise Exception(
"FAILS BECAUSE THE MESSAGE IS BIGGER THAN THE INTERNAL READ BUFFER" "FAILS BECAUSE THE MESSAGE IS BIGGER THAN THE INTERNAL READ BUFFER"
) )

View File

@ -17,6 +17,7 @@
import pytest import pytest
from trezorlib import btc, device, messages, misc, models from trezorlib import btc, device, messages, misc, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import LayoutType from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
@ -61,17 +62,17 @@ def _assert_protection(
client.refresh_features() client.refresh_features()
assert client.features.pin_protection is pin assert client.features.pin_protection is pin
assert client.features.passphrase_protection is passphrase assert client.features.passphrase_protection is passphrase
if session.session_version == Session.THP_V2: if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
new_session = session.client.get_session() new_session = session.client.get_session()
session.lock() session.lock()
session.end() session.end()
if session.session_version == Session.CODEC_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
new_session = session.client.get_session() new_session = session.client.get_session()
return Session(new_session) return Session(new_session)
def test_initialize(session: Session): def test_initialize(session: Session):
if session.session_version == Session.THP_V2: if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
# Test is skipped for THP # Test is skipped for THP
return return
@ -194,7 +195,7 @@ def test_get_public_key(session: Session):
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)] expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest) expected_responses.append(messages.PassphraseRequest)
expected_responses.append(messages.PublicKey) expected_responses.append(messages.PublicKey)
@ -208,7 +209,7 @@ def test_get_address(session: Session):
with session, session.client as client: with session, session.client as client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)] expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest) expected_responses.append(messages.PassphraseRequest)
expected_responses.append(messages.Address) expected_responses.append(messages.Address)
@ -323,7 +324,7 @@ def test_sign_message(session: Session):
expected_responses = [_pin_request(session)] expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest) expected_responses.append(messages.PassphraseRequest)
expected_responses.extend( expected_responses.extend(
@ -409,7 +410,7 @@ def test_signtx(session: Session):
with session, session.client as client: with session, session.client as client:
client.use_pin_sequence([PIN4]) client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)] expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest) expected_responses.append(messages.PassphraseRequest)
expected_responses.extend( expected_responses.extend(
[ [
@ -463,11 +464,11 @@ def test_unlocked(session: Session):
def test_passphrase_cached(session: Session): def test_passphrase_cached(session: Session):
session = _assert_protection(session, pin=False) session = _assert_protection(session, pin=False)
with session: with session:
if session.session_version == 1: if session.protocol_version == 1:
session.set_expected_responses( session.set_expected_responses(
[messages.PassphraseRequest, messages.Address] [messages.PassphraseRequest, messages.Address]
) )
elif session.session_version == 2: elif session.protocol_version == 2:
session.set_expected_responses([messages.Address]) session.set_expected_responses([messages.Address])
else: else:
raise Exception("Unknown session type") raise Exception("Unknown session type")

View File

@ -18,6 +18,7 @@ import pytest
from trezorlib import cardano, messages, models from trezorlib import cardano, messages, models
from trezorlib.btc import get_public_node from trezorlib.btc import get_public_node
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure from trezorlib.exceptions import TrezorFailure
@ -33,7 +34,7 @@ PIN4 = "1234"
def test_thp_end_session(client: Client): def test_thp_end_session(client: Client):
session = Session(client.get_session()) session = Session(client.get_session())
if session.session_version == Session.CODEC_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
# TODO: This test should be skipped on non-THP builds # TODO: This test should be skipped on non-THP builds
return return

View File

@ -19,6 +19,7 @@ import random
import pytest import pytest
from trezorlib import device, exceptions, messages from trezorlib import device, exceptions, messages
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import LayoutType from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
@ -65,7 +66,10 @@ def _get_xpub(
] ]
else: else:
expected_responses = [messages.PublicKey] expected_responses = [messages.PublicKey]
if passphrase_v1 is not None and session.session_version == Session.CODEC_V1: if (
passphrase_v1 is not None
and session.protocol_version == ProtocolVersion.PROTOCOL_V1
):
session.passphrase = passphrase_v1 session.passphrase = passphrase_v1
with session: with session:

View File

@ -17,8 +17,8 @@
import pytest import pytest
from trezorlib import device, messages from trezorlib import device, messages
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import DebugLink, LayoutType from trezorlib.debuglink import DebugLink, LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.messages import RecoveryStatus from trezorlib.messages import RecoveryStatus
from ..click_tests import common, recovery from ..click_tests import common, recovery
@ -158,7 +158,7 @@ def test_recovery_on_old_wallet(core_emulator: Emulator):
layout = debug.read_layout() layout = debug.read_layout()
# while keyboard is open, hit the device with Initialize/GetFeatures # while keyboard is open, hit the device with Initialize/GetFeatures
if device_handler.client.session_version == Session.CODEC_V1: if device_handler.client.protocol_version == ProtocolVersion.PROTOCOL_V1:
device_handler.client.get_management_session().call(messages.Initialize()) device_handler.client.get_management_session().call(messages.Initialize())
device_handler.client.refresh_features() device_handler.client.refresh_features()

View File

@ -21,6 +21,7 @@ import pytest
from shamir_mnemonic import shamir from shamir_mnemonic import shamir
from trezorlib import btc, debuglink, device, exceptions, fido, messages, models from trezorlib import btc, debuglink, device, exceptions, fido, messages, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.messages import ( from trezorlib.messages import (
ApplySettings, ApplySettings,
@ -373,7 +374,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
# Get a passphrase-less and a passphrased address. # Get a passphrase-less and a passphrased address.
address = btc.get_address(session, "Bitcoin", PATH) address = btc.get_address(session, "Bitcoin", PATH)
if session.session_version == Session.CODEC_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
session.call(messages.Initialize(new_session=True)) session.call(messages.Initialize(new_session=True))
new_session = emu.client.get_session(passphrase="TREZOR") new_session = emu.client.get_session(passphrase="TREZOR")
address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) address_passphrase = btc.get_address(new_session, "Bitcoin", PATH)

View File

@ -20,6 +20,7 @@ import pytest
from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib import btc, device, mapping, messages, models, protobuf
from trezorlib._internal.emulator import Emulator from trezorlib._internal.emulator import Emulator
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
@ -139,7 +140,7 @@ def test_init_device(emulator: Emulator):
btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0")) btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0"))
# in TT < 2.3.0 session_id will only be available after PassphraseStateRequest # in TT < 2.3.0 session_id will only be available after PassphraseStateRequest
session_id = session.id session_id = session.id
if session.session_version == Session.CODEC_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
session.call(messages.Initialize(session_id=session_id)) session.call(messages.Initialize(session_id=session_id))
btc.get_address( btc.get_address(
session, session,