mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-02 20:48:30 +00:00
chore(python): unify session and protocol versions under one IntEnum
[no changelog]
This commit is contained in:
parent
dbb0d44ca5
commit
a29f21b2be
@ -26,7 +26,7 @@ from contextlib import contextmanager
|
||||
import click
|
||||
|
||||
from .. import exceptions, transport, ui
|
||||
from ..client import PROTOCOL_V2, TrezorClient
|
||||
from ..client import ProtocolVersion, TrezorClient
|
||||
from ..messages import Capability
|
||||
from ..transport import Transport
|
||||
from ..transport.session import Session, SessionV1, SessionV2
|
||||
@ -150,11 +150,11 @@ class TrezorConnection:
|
||||
|
||||
# Try resume session from id
|
||||
if self.session_id is not None:
|
||||
if client.protocol_version is Session.CODEC_V1:
|
||||
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
|
||||
session = SessionV1.resume_from_id(
|
||||
client=client, session_id=self.session_id
|
||||
)
|
||||
elif client.protocol_version is Session.THP_V2:
|
||||
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||
session = SessionV2(client, self.session_id)
|
||||
# TODO fix resumption on THP
|
||||
else:
|
||||
@ -311,7 +311,7 @@ def with_client(
|
||||
try:
|
||||
return func(client, *args, **kwargs)
|
||||
finally:
|
||||
if client.protocol_version == PROTOCOL_V2:
|
||||
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||
get_channel_db().save_channel(client.protocol)
|
||||
# if not session_was_resumed:
|
||||
# try:
|
||||
|
@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca
|
||||
import click
|
||||
|
||||
from .. import __version__, log, messages, protobuf
|
||||
from ..client import TrezorClient
|
||||
from ..client import ProtocolVersion, TrezorClient
|
||||
from ..transport import DeviceIsBusy, enumerate_devices
|
||||
from ..transport.session import Session
|
||||
from ..transport.thp import channel_database
|
||||
@ -308,7 +308,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
|
||||
try:
|
||||
client = get_client(transport)
|
||||
description = format_device_name(client.features)
|
||||
if client.protocol_version == Session.THP_V2:
|
||||
if client.protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||
get_channel_db().save_channel(client.protocol)
|
||||
except DeviceIsBusy:
|
||||
description = "Device is in use by another process"
|
||||
|
@ -18,6 +18,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
from enum import IntEnum
|
||||
|
||||
from . import mapping, messages, models
|
||||
from .mapping import ProtobufMapping
|
||||
@ -48,9 +49,11 @@ Or visit https://suite.trezor.io/
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
UNKNOWN = -1
|
||||
PROTOCOL_V1 = 1
|
||||
PROTOCOL_V2 = 2
|
||||
|
||||
class ProtocolVersion(IntEnum):
|
||||
UNKNOWN = 0x00
|
||||
PROTOCOL_V1 = 0x01 # Codec
|
||||
PROTOCOL_V2 = 0x02 # THP
|
||||
|
||||
|
||||
class TrezorClient:
|
||||
@ -80,12 +83,13 @@ class TrezorClient:
|
||||
else:
|
||||
self.protocol = protocol
|
||||
self.protocol.mapping = self.mapping
|
||||
|
||||
if isinstance(self.protocol, ProtocolV1):
|
||||
self._protocol_version = PROTOCOL_V1
|
||||
self._protocol_version = ProtocolVersion.PROTOCOL_V1
|
||||
elif isinstance(self.protocol, ProtocolV2):
|
||||
self._protocol_version = PROTOCOL_V2
|
||||
self._protocol_version = ProtocolVersion.PROTOCOL_V2
|
||||
else:
|
||||
self._protocol_version = UNKNOWN
|
||||
self._protocol_version = ProtocolVersion.UNKNOWN
|
||||
|
||||
@classmethod
|
||||
def resume(
|
||||
|
@ -42,7 +42,7 @@ from .exceptions import Cancelled, PinException, TrezorFailure
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import Capability, DebugWaitType
|
||||
from .tools import expect, parse_path
|
||||
from .transport.session import Session, SessionV1, SessionV2
|
||||
from .transport.session import Session, SessionV1
|
||||
from .transport.thp.protocol_v1 import ProtocolV1
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@ -1031,18 +1031,12 @@ class SessionDebugWrapper(Session):
|
||||
def __init__(self, session: Session) -> None:
|
||||
self._session = session
|
||||
self.reset_debug_features()
|
||||
if isinstance(session, SessionV1):
|
||||
self.client.session_version = 1
|
||||
elif isinstance(session, SessionV2):
|
||||
self.client.session_version = 2
|
||||
elif isinstance(session, SessionDebugWrapper):
|
||||
if isinstance(session, SessionDebugWrapper):
|
||||
raise Exception("Cannot wrap already wrapped session!")
|
||||
else:
|
||||
self.client.session_version = -1 # UNKNOWN
|
||||
|
||||
@property
|
||||
def session_version(self) -> int:
|
||||
return self.client.session_version
|
||||
def protocol_version(self) -> int:
|
||||
return self.client.protocol_version
|
||||
|
||||
@property
|
||||
def client(self) -> TrezorClientDebugLink:
|
||||
@ -1284,7 +1278,6 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
# by the device.
|
||||
|
||||
def __init__(self, transport: "Transport", auto_interact: bool = True) -> None:
|
||||
self._session_version: int = -1
|
||||
try:
|
||||
debug_transport = transport.find_debug()
|
||||
self.debug = DebugLink(debug_transport, auto_interact)
|
||||
@ -1311,14 +1304,6 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
self.debug.version = self.version
|
||||
self.passphrase: str | None = None
|
||||
|
||||
@property
|
||||
def session_version(self) -> int:
|
||||
return self._session_version
|
||||
|
||||
@session_version.setter
|
||||
def session_version(self, value: int) -> None:
|
||||
self._session_version = value
|
||||
|
||||
@property
|
||||
def layout_type(self) -> LayoutType:
|
||||
return self.debug.layout_type
|
||||
|
@ -14,8 +14,6 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Session:
|
||||
CODEC_V1: t.Final[int] = 1
|
||||
THP_V2: t.Final[int] = 2
|
||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
|
@ -15,6 +15,7 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from trezorlib import device, messages, models
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
|
||||
@ -22,7 +23,7 @@ from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
def test_features(client: Client):
|
||||
session = client.get_session()
|
||||
f0 = session.features
|
||||
if Session(session).session_version == Session.CODEC_V1:
|
||||
if client.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
# session erases session_id from its features
|
||||
f0.session_id = session.id
|
||||
f1 = session.call(messages.Initialize(session_id=session.id))
|
||||
|
@ -17,6 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import debuglink, device, messages, misc
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.tools import parse_path
|
||||
@ -62,7 +63,7 @@ def test_pin(session: Session):
|
||||
|
||||
@pytest.mark.models("core")
|
||||
def test_softlock_instability(session: Session):
|
||||
if session.session_version == Session.THP_V2:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||
raise Exception("THIS NEEDS TO BE CHANGED FOR THP")
|
||||
|
||||
def load_device():
|
||||
|
@ -19,6 +19,7 @@ from pathlib import Path
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, device, exceptions, messages, misc, models
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
@ -205,7 +206,7 @@ def test_apply_homescreen_toif(session: Session):
|
||||
|
||||
@pytest.mark.models(skip=["legacy", "safe3"])
|
||||
def test_apply_homescreen_jpeg(session: Session):
|
||||
if session.session_version is Session.THP_V2:
|
||||
if session.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||
raise Exception(
|
||||
"FAILS BECAUSE THE MESSAGE IS BIGGER THAN THE INTERNAL READ BUFFER"
|
||||
)
|
||||
|
@ -17,6 +17,7 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import btc, device, messages, misc, models
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import LayoutType
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.exceptions import TrezorFailure
|
||||
@ -61,17 +62,17 @@ def _assert_protection(
|
||||
client.refresh_features()
|
||||
assert client.features.pin_protection is pin
|
||||
assert client.features.passphrase_protection is passphrase
|
||||
if session.session_version == Session.THP_V2:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||
new_session = session.client.get_session()
|
||||
session.lock()
|
||||
session.end()
|
||||
if session.session_version == Session.CODEC_V1:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
new_session = session.client.get_session()
|
||||
return Session(new_session)
|
||||
|
||||
|
||||
def test_initialize(session: Session):
|
||||
if session.session_version == Session.THP_V2:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||
# Test is skipped for THP
|
||||
return
|
||||
|
||||
@ -194,7 +195,7 @@ def test_get_public_key(session: Session):
|
||||
client.use_pin_sequence([PIN4])
|
||||
expected_responses = [_pin_request(session)]
|
||||
|
||||
if session.session_version == Session.CODEC_V1:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
expected_responses.append(messages.PassphraseRequest)
|
||||
expected_responses.append(messages.PublicKey)
|
||||
|
||||
@ -208,7 +209,7 @@ def test_get_address(session: Session):
|
||||
with session, session.client as client:
|
||||
client.use_pin_sequence([PIN4])
|
||||
expected_responses = [_pin_request(session)]
|
||||
if session.session_version == Session.CODEC_V1:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
expected_responses.append(messages.PassphraseRequest)
|
||||
expected_responses.append(messages.Address)
|
||||
|
||||
@ -323,7 +324,7 @@ def test_sign_message(session: Session):
|
||||
|
||||
expected_responses = [_pin_request(session)]
|
||||
|
||||
if session.session_version == Session.CODEC_V1:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
expected_responses.append(messages.PassphraseRequest)
|
||||
|
||||
expected_responses.extend(
|
||||
@ -409,7 +410,7 @@ def test_signtx(session: Session):
|
||||
with session, session.client as client:
|
||||
client.use_pin_sequence([PIN4])
|
||||
expected_responses = [_pin_request(session)]
|
||||
if session.session_version == Session.CODEC_V1:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
expected_responses.append(messages.PassphraseRequest)
|
||||
expected_responses.extend(
|
||||
[
|
||||
@ -463,11 +464,11 @@ def test_unlocked(session: Session):
|
||||
def test_passphrase_cached(session: Session):
|
||||
session = _assert_protection(session, pin=False)
|
||||
with session:
|
||||
if session.session_version == 1:
|
||||
if session.protocol_version == 1:
|
||||
session.set_expected_responses(
|
||||
[messages.PassphraseRequest, messages.Address]
|
||||
)
|
||||
elif session.session_version == 2:
|
||||
elif session.protocol_version == 2:
|
||||
session.set_expected_responses([messages.Address])
|
||||
else:
|
||||
raise Exception("Unknown session type")
|
||||
|
@ -18,6 +18,7 @@ import pytest
|
||||
|
||||
from trezorlib import cardano, messages, models
|
||||
from trezorlib.btc import get_public_node
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.exceptions import TrezorFailure
|
||||
@ -33,7 +34,7 @@ PIN4 = "1234"
|
||||
|
||||
def test_thp_end_session(client: Client):
|
||||
session = Session(client.get_session())
|
||||
if session.session_version == Session.CODEC_V1:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
# TODO: This test should be skipped on non-THP builds
|
||||
return
|
||||
|
||||
|
@ -19,6 +19,7 @@ import random
|
||||
import pytest
|
||||
|
||||
from trezorlib import device, exceptions, messages
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import LayoutType
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
@ -65,7 +66,10 @@ def _get_xpub(
|
||||
]
|
||||
else:
|
||||
expected_responses = [messages.PublicKey]
|
||||
if passphrase_v1 is not None and session.session_version == Session.CODEC_V1:
|
||||
if (
|
||||
passphrase_v1 is not None
|
||||
and session.protocol_version == ProtocolVersion.PROTOCOL_V1
|
||||
):
|
||||
session.passphrase = passphrase_v1
|
||||
|
||||
with session:
|
||||
|
@ -17,8 +17,8 @@
|
||||
import pytest
|
||||
|
||||
from trezorlib import device, messages
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import DebugLink, LayoutType
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.messages import RecoveryStatus
|
||||
|
||||
from ..click_tests import common, recovery
|
||||
@ -158,7 +158,7 @@ def test_recovery_on_old_wallet(core_emulator: Emulator):
|
||||
layout = debug.read_layout()
|
||||
|
||||
# while keyboard is open, hit the device with Initialize/GetFeatures
|
||||
if device_handler.client.session_version == Session.CODEC_V1:
|
||||
if device_handler.client.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
device_handler.client.get_management_session().call(messages.Initialize())
|
||||
device_handler.client.refresh_features()
|
||||
|
||||
|
@ -21,6 +21,7 @@ import pytest
|
||||
from shamir_mnemonic import shamir
|
||||
|
||||
from trezorlib import btc, debuglink, device, exceptions, fido, messages, models
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.messages import (
|
||||
ApplySettings,
|
||||
@ -373,7 +374,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
|
||||
|
||||
# Get a passphrase-less and a passphrased address.
|
||||
address = btc.get_address(session, "Bitcoin", PATH)
|
||||
if session.session_version == Session.CODEC_V1:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
session.call(messages.Initialize(new_session=True))
|
||||
new_session = emu.client.get_session(passphrase="TREZOR")
|
||||
address_passphrase = btc.get_address(new_session, "Bitcoin", PATH)
|
||||
|
@ -20,6 +20,7 @@ import pytest
|
||||
|
||||
from trezorlib import btc, device, mapping, messages, models, protobuf
|
||||
from trezorlib._internal.emulator import Emulator
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import SessionDebugWrapper as Session
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
@ -139,7 +140,7 @@ def test_init_device(emulator: Emulator):
|
||||
btc.get_address(session, "Testnet", parse_path("44h/1h/0h/0/0"))
|
||||
# in TT < 2.3.0 session_id will only be available after PassphraseStateRequest
|
||||
session_id = session.id
|
||||
if session.session_version == Session.CODEC_V1:
|
||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
session.call(messages.Initialize(session_id=session_id))
|
||||
btc.get_address(
|
||||
session,
|
||||
|
Loading…
Reference in New Issue
Block a user