1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-02 20:48:30 +00:00

chore(python): unify session and protocol versions under one IntEnum

[no changelog]
This commit is contained in:
M1nd3r 2024-11-29 11:11:00 +01:00
parent dbb0d44ca5
commit a29f21b2be
14 changed files with 49 additions and 51 deletions

View File

@ -26,7 +26,7 @@ from contextlib import contextmanager
import click
from .. import exceptions, transport, ui
from ..client import PROTOCOL_V2, TrezorClient
from ..client import ProtocolVersion, TrezorClient
from ..messages import Capability
from ..transport import Transport
from ..transport.session import Session, SessionV1, SessionV2
@ -150,11 +150,11 @@ class TrezorConnection:
# Try resume session from id
if self.session_id is not None:
if client.protocol_version is Session.CODEC_V1:
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
session = SessionV1.resume_from_id(
client=client, session_id=self.session_id
)
elif client.protocol_version is Session.THP_V2:
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
session = SessionV2(client, self.session_id)
# TODO fix resumption on THP
else:
@ -311,7 +311,7 @@ def with_client(
try:
return func(client, *args, **kwargs)
finally:
if client.protocol_version == PROTOCOL_V2:
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
# if not session_was_resumed:
# try:

View File

@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
import click
from .. import __version__, log, messages, protobuf
from ..client import TrezorClient
from ..client import ProtocolVersion, TrezorClient
from ..transport import DeviceIsBusy, enumerate_devices
from ..transport.session import Session
from ..transport.thp import channel_database
@ -308,7 +308,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
try:
client = get_client(transport)
description = format_device_name(client.features)
if client.protocol_version == Session.THP_V2:
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
get_channel_db().save_channel(client.protocol)
except DeviceIsBusy:
description = "Device is in use by another process"

View File

@ -18,6 +18,7 @@ from __future__ import annotations
import logging
import os
import typing as t
from enum import IntEnum
from . import mapping, messages, models
from .mapping import ProtobufMapping
@ -48,9 +49,11 @@ Or visit https://suite.trezor.io/
LOG = logging.getLogger(__name__)
UNKNOWN = -1
PROTOCOL_V1 = 1
PROTOCOL_V2 = 2
class ProtocolVersion(IntEnum):
UNKNOWN = 0x00
PROTOCOL_V1 = 0x01 # Codec
PROTOCOL_V2 = 0x02 # THP
class TrezorClient:
@ -80,12 +83,13 @@ class TrezorClient:
else:
self.protocol = protocol
self.protocol.mapping = self.mapping
if isinstance(self.protocol, ProtocolV1):
self._protocol_version = PROTOCOL_V1
self._protocol_version = ProtocolVersion.PROTOCOL_V1
elif isinstance(self.protocol, ProtocolV2):
self._protocol_version = PROTOCOL_V2
self._protocol_version = ProtocolVersion.PROTOCOL_V2
else:
self._protocol_version = UNKNOWN
self._protocol_version = ProtocolVersion.UNKNOWN
@classmethod
def resume(

View File

@ -42,7 +42,7 @@ from .exceptions import Cancelled, PinException, TrezorFailure
from .log import DUMP_BYTES
from .messages import Capability, DebugWaitType
from .tools import expect, parse_path
from .transport.session import Session, SessionV1, SessionV2
from .transport.session import Session, SessionV1
from .transport.thp.protocol_v1 import ProtocolV1
if t.TYPE_CHECKING:
@ -1031,18 +1031,12 @@ class SessionDebugWrapper(Session):
def __init__(self, session: Session) -> None:
self._session = session
self.reset_debug_features()
if isinstance(session, SessionV1):
self.client.session_version = 1
elif isinstance(session, SessionV2):
self.client.session_version = 2
elif isinstance(session, SessionDebugWrapper):
if isinstance(session, SessionDebugWrapper):
raise Exception("Cannot wrap already wrapped session!")
else:
self.client.session_version = -1 # UNKNOWN
@property
def session_version(self) -> int:
return self.client.session_version
def protocol_version(self) -> int:
return self.client.protocol_version
@property
def client(self) -> TrezorClientDebugLink:
@ -1284,7 +1278,6 @@ class TrezorClientDebugLink(TrezorClient):
# by the device.
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
self._session_version: int = -1
try:
debug_transport = transport.find_debug()
self.debug = DebugLink(debug_transport, auto_interact)
@ -1311,14 +1304,6 @@ class TrezorClientDebugLink(TrezorClient):
self.debug.version = self.version
self.passphrase: str | None = None
@property
def session_version(self) -> int:
return self._session_version
@session_version.setter
def session_version(self, value: int) -> None:
self._session_version = value
@property
def layout_type(self) -> LayoutType:
return self.debug.layout_type

View File

@ -14,8 +14,6 @@ LOG = logging.getLogger(__name__)
class Session:
CODEC_V1: t.Final[int] = 1
THP_V2: t.Final[int] = 2
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None

View File

@ -15,6 +15,7 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from trezorlib import device, messages, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
@ -22,7 +23,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
def test_features(client: Client):
session = client.get_session()
f0 = session.features
if Session(session).session_version == Session.CODEC_V1:
if client.protocol_version == ProtocolVersion.PROTOCOL_V1:
# session erases session_id from its features
f0.session_id = session.id
f1 = session.call(messages.Initialize(session_id=session.id))

View File

@ -17,6 +17,7 @@
import pytest
from trezorlib import debuglink, device, messages, misc
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.tools import parse_path
@ -62,7 +63,7 @@ def test_pin(session: Session):
@pytest.mark.models("core")
def test_softlock_instability(session: Session):
if session.session_version == Session.THP_V2:
if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
raise Exception("THIS NEEDS TO BE CHANGED FOR THP")
def load_device():

View File

@ -19,6 +19,7 @@ from pathlib import Path
import pytest
from trezorlib import btc, device, exceptions, messages, misc, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
@ -205,7 +206,7 @@ def test_apply_homescreen_toif(session: Session):
@pytest.mark.models(skip=["legacy", "safe3"])
def test_apply_homescreen_jpeg(session: Session):
if session.session_version is Session.THP_V2:
if session.protocol_version is ProtocolVersion.PROTOCOL_V2:
raise Exception(
"FAILS BECAUSE THE MESSAGE IS BIGGER THAN THE INTERNAL READ BUFFER"
)

View File

@ -17,6 +17,7 @@
import pytest
from trezorlib import btc, device, messages, misc, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.exceptions import TrezorFailure
@ -61,17 +62,17 @@ def _assert_protection(
client.refresh_features()
assert client.features.pin_protection is pin
assert client.features.passphrase_protection is passphrase
if session.session_version == Session.THP_V2:
if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
new_session = session.client.get_session()
session.lock()
session.end()
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
new_session = session.client.get_session()
return Session(new_session)
def test_initialize(session: Session):
if session.session_version == Session.THP_V2:
if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
# Test is skipped for THP
return
@ -194,7 +195,7 @@ def test_get_public_key(session: Session):
client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)
expected_responses.append(messages.PublicKey)
@ -208,7 +209,7 @@ def test_get_address(session: Session):
with session, session.client as client:
client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)
expected_responses.append(messages.Address)
@ -323,7 +324,7 @@ def test_sign_message(session: Session):
expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)
expected_responses.extend(
@ -409,7 +410,7 @@ def test_signtx(session: Session):
with session, session.client as client:
client.use_pin_sequence([PIN4])
expected_responses = [_pin_request(session)]
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
expected_responses.append(messages.PassphraseRequest)
expected_responses.extend(
[
@ -463,11 +464,11 @@ def test_unlocked(session: Session):
def test_passphrase_cached(session: Session):
session = _assert_protection(session, pin=False)
with session:
if session.session_version == 1:
if session.protocol_version == 1:
session.set_expected_responses(
[messages.PassphraseRequest, messages.Address]
)
elif session.session_version == 2:
elif session.protocol_version == 2:
session.set_expected_responses([messages.Address])
else:
raise Exception("Unknown session type")

View File

@ -18,6 +18,7 @@ import pytest
from trezorlib import cardano, messages, models
from trezorlib.btc import get_public_node
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.exceptions import TrezorFailure
@ -33,7 +34,7 @@ PIN4 = "1234"
def test_thp_end_session(client: Client):
session = Session(client.get_session())
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
# TODO: This test should be skipped on non-THP builds
return

View File

@ -19,6 +19,7 @@ import random
import pytest
from trezorlib import device, exceptions, messages
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client
@ -65,7 +66,10 @@ def _get_xpub(
]
else:
expected_responses = [messages.PublicKey]
if passphrase_v1 is not None and session.session_version == Session.CODEC_V1:
if (
passphrase_v1 is not None
and session.protocol_version == ProtocolVersion.PROTOCOL_V1
):
session.passphrase = passphrase_v1
with session:

View File

@ -17,8 +17,8 @@
import pytest
from trezorlib import device, messages
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import DebugLink, LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.messages import RecoveryStatus
from ..click_tests import common, recovery
@ -158,7 +158,7 @@ def test_recovery_on_old_wallet(core_emulator: Emulator):
layout = debug.read_layout()
# while keyboard is open, hit the device with Initialize/GetFeatures
if device_handler.client.session_version == Session.CODEC_V1:
if device_handler.client.protocol_version == ProtocolVersion.PROTOCOL_V1:
device_handler.client.get_management_session().call(messages.Initialize())
device_handler.client.refresh_features()

View File

@ -21,6 +21,7 @@ import pytest
from shamir_mnemonic import shamir
from trezorlib import btc, debuglink, device, exceptions, fido, messages, models
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.messages import (
ApplySettings,
@ -373,7 +374,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
# Get a passphrase-less and a passphrased address.
address = btc.get_address(session, "Bitcoin", PATH)
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
session.call(messages.Initialize(new_session=True))
new_session = emu.client.get_session(passphrase="TREZOR")
address_passphrase = btc.get_address(new_session, "Bitcoin", PATH)

View File

@ -20,6 +20,7 @@ import pytest
from trezorlib import btc, device, mapping, messages, models, protobuf
from trezorlib._internal.emulator import Emulator
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.tools import parse_path
@ -139,7 +140,7 @@ def test_init_device(emulator: Emulator):
btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0"))
# in TT < 2.3.0 session_id will only be available after PassphraseStateRequest
session_id = session.id
if session.session_version == Session.CODEC_V1:
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
session.call(messages.Initialize(session_id=session_id))
btc.get_address(
session,