From 33547c80f41c11d34f098b69c061af965f131673 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 8 Apr 2025 16:16:48 +0200 Subject: [PATCH] fix style and errors --- core/src/trezor/wire/__init__.py | 4 +-- core/src/trezor/wire/thp/channel.py | 4 +-- python/src/trezorlib/cli/__init__.py | 4 +-- python/src/trezorlib/client.py | 28 +++++++++++++------ python/src/trezorlib/debuglink.py | 15 +++------- python/src/trezorlib/transport/bridge.py | 6 ++-- tests/click_tests/test_reset_slip39_basic.py | 2 +- tests/conftest.py | 12 ++------ .../bitcoin/test_authorize_coinjoin.py | 2 +- tests/device_tests/thp/test_multiple_hosts.py | 2 ++ tests/ui_tests/__init__.py | 1 - 11 files changed, 39 insertions(+), 41 deletions(-) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index f5a4824d72..d5dcd485e4 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -102,11 +102,9 @@ if utils.USE_THP: return # pylint: disable=lost-exception else: - _PROTOBUF_BUFFER_SIZE = const(8192) - WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) async def handle_session(iface: WireInterface) -> None: - ctx = CodecContext(iface, WIRE_BUFFER) + ctx = CodecContext(iface, WIRE_BUFFER_PROVIDER) next_msg: protocol_common.Message | None = None # Take a mark of modules that are imported at this point, so we can diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 287f3b3a43..80485c8320 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -161,8 +161,8 @@ class Channel: pass # TODO ?? if self.fallback_decrypt and self.expected_payload_length == self.bytes_read: self._finish_fallback() - from trezor.messages import Failure from trezor.enums import FailureType + from trezor.messages import Failure return self.write( Failure(code=FailureType.DeviceIsBusy, message="FALLBACK!"), @@ -201,7 +201,7 @@ class Channel: return None return self._handle_init_packet(packet) - def _handle_init_packet(self, packet: utils.BufferType) -> None: + def _handle_init_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: self.fallback_decrypt = False self.fallback_session_id = None self.bytes_read = 0 diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 11eaafe054..cc7e031b70 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -135,11 +135,11 @@ class TrezorConnection: # Try resume session from id if self.session_id is not None: - if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + if client.protocol_version is ProtocolVersion.V1: session = SessionV1.resume_from_id( client=client, session_id=self.session_id ) - elif client.protocol_version is ProtocolVersion.PROTOCOL_V2: + elif client.protocol_version is ProtocolVersion.V2: session = SessionV2(client, self.session_id) # TODO fix resumption on THP else: diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 32b50d304e..3e18e77a54 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -95,7 +95,7 @@ class TrezorClient: if isinstance(self.protocol, ProtocolV1Channel): self._protocol_version = ProtocolVersion.V1 elif isinstance(self.protocol, ProtocolV2Channel): - self._protocol_version = ProtocolVersion.PROTOCOL_V2 + self._protocol_version = ProtocolVersion.V2 else: raise Exception("Unknown protocol version") @@ -129,7 +129,7 @@ class TrezorClient: from .transport.session import SessionV2 assert isinstance(passphrase, str) or passphrase is None - session_id = 1 # TODO fix this with ProtocolV2 session rework + session_id = b"\x01" # TODO fix this with ProtocolV2 session rework if session_id is not None: sid = int.from_bytes(session_id, "big") else: @@ -197,9 +197,9 @@ class TrezorClient: return protocol def reset_protocol(self): - if self._protocol_version == ProtocolVersion.PROTOCOL_V1: + if self._protocol_version == ProtocolVersion.V1: self.protocol = ProtocolV1Channel(self.transport, self.mapping) - elif self._protocol_version == ProtocolVersion.PROTOCOL_V2: + elif self._protocol_version == ProtocolVersion.V2: self.protocol = ProtocolV2Channel(self.transport, self.mapping) else: assert False @@ -217,11 +217,23 @@ class TrezorClient: else: raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) - def _write(self, msg: t.Any) -> None: - self.protocol.write(msg) + def _write(self, msg: t.Any, session_id: int | None = None) -> None: + if isinstance(self.protocol, ProtocolV1Channel): + self.protocol.write(msg) + elif isinstance(self.protocol, ProtocolV2Channel): + assert session_id is not None + self.protocol.write(session_id=session_id, msg=msg) + else: + raise Exception("Unknown client protocol") - def _read(self) -> t.Any: - return self.protocol.read() + def _read(self, session_id: int | None = None) -> t.Any: + if isinstance(self.protocol, ProtocolV1Channel): + self.protocol.read() + elif isinstance(self.protocol, ProtocolV2Channel): + assert session_id is not None + self.protocol.read(session_id=session_id) + else: + raise Exception("Unknown client protocol") def get_default_client( diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 51ca52586e..d671d8d1ab 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -32,20 +32,13 @@ from pathlib import Path from mnemonic import Mnemonic from . import btc, mapping, messages, models, protobuf -from .client import ( - MAX_PASSPHRASE_LENGTH, - MAX_PIN_LENGTH, - PASSPHRASE_ON_DEVICE, - ProtocolVersion, - TrezorClient, -) -from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError +from .client import ProtocolVersion, TrezorClient +from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError from .log import DUMP_BYTES -from .messages import Capability, DebugWaitType -from .protobuf import MessageType +from .messages import DebugWaitType from .tools import parse_path from .transport import Timeout -from .transport.session import Session, SessionV1, derive_seed +from .transport.session import Session from .transport.thp.protocol_v1 import ProtocolV1Channel if t.TYPE_CHECKING: diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index cf9f322f78..9ec65d523b 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -76,7 +76,7 @@ def detect_protocol_version(transport: "BridgeTransport") -> int: from .. import mapping, messages from ..messages import FailureType - protocol_version = ProtocolVersion.PROTOCOL_V1 + protocol_version = ProtocolVersion.V1 request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize()) transport.open() transport.write_chunk(request_type.to_bytes(2, "big") + request_data) @@ -87,13 +87,13 @@ def detect_protocol_version(transport: "BridgeTransport") -> int: if isinstance(response, messages.Failure): if response.code == FailureType.InvalidProtocol: LOG.debug("Protocol V2 detected") - protocol_version = ProtocolVersion.PROTOCOL_V2 + protocol_version = ProtocolVersion.V2 return protocol_version def _is_transport_valid(transport: "BridgeTransport") -> bool: - is_valid = detect_protocol_version(transport) == ProtocolVersion.PROTOCOL_V1 + is_valid = detect_protocol_version(transport) == ProtocolVersion.V1 if not is_valid: LOG.warning("Detected unsupported Bridge transport!") return is_valid diff --git a/tests/click_tests/test_reset_slip39_basic.py b/tests/click_tests/test_reset_slip39_basic.py index cb379beff2..fd601e0983 100644 --- a/tests/click_tests/test_reset_slip39_basic.py +++ b/tests/click_tests/test_reset_slip39_basic.py @@ -59,7 +59,7 @@ def test_reset_slip39_basic( entropy_check_count=0, _get_entropy=MOCK_GET_ENTROPY, ) - if device_handler.client.protocol_version is ProtocolVersion.PROTOCOL_V2: + if device_handler.client.protocol_version is ProtocolVersion.V2: reset.confirm_read(debug, middle_r=True) # confirm new wallet diff --git a/tests/conftest.py b/tests/conftest.py index 28d2bb75ee..87b719313d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -293,23 +293,17 @@ def _client_unlocked( args = protocol_marker.args protocol_version = _raw_client.protocol_version - if ( - protocol_version == ProtocolVersion.PROTOCOL_V1 - and "protocol_v1" not in args - ): + if protocol_version == ProtocolVersion.V1 and "protocol_v1" not in args: pytest.skip( f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." ) - if ( - protocol_version == ProtocolVersion.PROTOCOL_V2 - and "protocol_v2" not in args - ): + if protocol_version == ProtocolVersion.V2 and "protocol_v2" not in args: pytest.skip( f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." ) - if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2: + if _raw_client.protocol_version is ProtocolVersion.V2: pass sd_marker = request.node.get_closest_marker("sd_card") if sd_marker and not _raw_client.features.sd_card_present: diff --git a/tests/device_tests/bitcoin/test_authorize_coinjoin.py b/tests/device_tests/bitcoin/test_authorize_coinjoin.py index d865930a80..2e0c293caf 100644 --- a/tests/device_tests/bitcoin/test_authorize_coinjoin.py +++ b/tests/device_tests/bitcoin/test_authorize_coinjoin.py @@ -808,7 +808,7 @@ def test_multisession_authorization(client: Client): ) # Open a second session. - if client.protocol_version is ProtocolVersion.PROTOCOL_V2: + if client.protocol_version is ProtocolVersion.V2: session_id = b"\x02" else: session_id = None diff --git a/tests/device_tests/thp/test_multiple_hosts.py b/tests/device_tests/thp/test_multiple_hosts.py index 5e38d414ad..6f9d708b75 100644 --- a/tests/device_tests/thp/test_multiple_hosts.py +++ b/tests/device_tests/thp/test_multiple_hosts.py @@ -1,10 +1,12 @@ import os from time import sleep + import pytest from trezorlib import messages from trezorlib.client import ProtocolV2Channel from trezorlib.debuglink import TrezorClientDebugLink as Client + from ...conftest import LOCK_TIME pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client] diff --git a/tests/ui_tests/__init__.py b/tests/ui_tests/__init__.py index 912c6c2754..093ba2cac1 100644 --- a/tests/ui_tests/__init__.py +++ b/tests/ui_tests/__init__.py @@ -8,7 +8,6 @@ import pytest from _pytest.nodes import Node from _pytest.outcomes import Failed -from trezorlib.client import ProtocolVersion from trezorlib.debuglink import TrezorClientDebugLink as Client from . import common