1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-15 06:45:59 +00:00

fix style and errors

This commit is contained in:
M1nd3r 2025-04-08 16:16:48 +02:00
parent 68ac7fb0ea
commit 33547c80f4
11 changed files with 39 additions and 41 deletions

View File

@ -102,11 +102,9 @@ if utils.USE_THP:
return # pylint: disable=lost-exception
else:
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
async def handle_session(iface: WireInterface) -> None:
ctx = CodecContext(iface, WIRE_BUFFER)
ctx = CodecContext(iface, WIRE_BUFFER_PROVIDER)
next_msg: protocol_common.Message | None = None
# Take a mark of modules that are imported at this point, so we can

View File

@ -161,8 +161,8 @@ class Channel:
pass # TODO ??
if self.fallback_decrypt and self.expected_payload_length == self.bytes_read:
self._finish_fallback()
from trezor.messages import Failure
from trezor.enums import FailureType
from trezor.messages import Failure
return self.write(
Failure(code=FailureType.DeviceIsBusy, message="FALLBACK!"),
@ -201,7 +201,7 @@ class Channel:
return None
return self._handle_init_packet(packet)
def _handle_init_packet(self, packet: utils.BufferType) -> None:
def _handle_init_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
self.fallback_decrypt = False
self.fallback_session_id = None
self.bytes_read = 0

View File

@ -135,11 +135,11 @@ class TrezorConnection:
# Try resume session from id
if self.session_id is not None:
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
if client.protocol_version is ProtocolVersion.V1:
session = SessionV1.resume_from_id(
client=client, session_id=self.session_id
)
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
elif client.protocol_version is ProtocolVersion.V2:
session = SessionV2(client, self.session_id)
# TODO fix resumption on THP
else:

View File

@ -95,7 +95,7 @@ class TrezorClient:
if isinstance(self.protocol, ProtocolV1Channel):
self._protocol_version = ProtocolVersion.V1
elif isinstance(self.protocol, ProtocolV2Channel):
self._protocol_version = ProtocolVersion.PROTOCOL_V2
self._protocol_version = ProtocolVersion.V2
else:
raise Exception("Unknown protocol version")
@ -129,7 +129,7 @@ class TrezorClient:
from .transport.session import SessionV2
assert isinstance(passphrase, str) or passphrase is None
session_id = 1 # TODO fix this with ProtocolV2 session rework
session_id = b"\x01" # TODO fix this with ProtocolV2 session rework
if session_id is not None:
sid = int.from_bytes(session_id, "big")
else:
@ -197,9 +197,9 @@ class TrezorClient:
return protocol
def reset_protocol(self):
if self._protocol_version == ProtocolVersion.PROTOCOL_V1:
if self._protocol_version == ProtocolVersion.V1:
self.protocol = ProtocolV1Channel(self.transport, self.mapping)
elif self._protocol_version == ProtocolVersion.PROTOCOL_V2:
elif self._protocol_version == ProtocolVersion.V2:
self.protocol = ProtocolV2Channel(self.transport, self.mapping)
else:
assert False
@ -217,11 +217,23 @@ class TrezorClient:
else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
def _write(self, msg: t.Any) -> None:
self.protocol.write(msg)
def _write(self, msg: t.Any, session_id: int | None = None) -> None:
if isinstance(self.protocol, ProtocolV1Channel):
self.protocol.write(msg)
elif isinstance(self.protocol, ProtocolV2Channel):
assert session_id is not None
self.protocol.write(session_id=session_id, msg=msg)
else:
raise Exception("Unknown client protocol")
def _read(self) -> t.Any:
return self.protocol.read()
def _read(self, session_id: int | None = None) -> t.Any:
if isinstance(self.protocol, ProtocolV1Channel):
self.protocol.read()
elif isinstance(self.protocol, ProtocolV2Channel):
assert session_id is not None
self.protocol.read(session_id=session_id)
else:
raise Exception("Unknown client protocol")
def get_default_client(

View File

@ -32,20 +32,13 @@ from pathlib import Path
from mnemonic import Mnemonic
from . import btc, mapping, messages, models, protobuf
from .client import (
MAX_PASSPHRASE_LENGTH,
MAX_PIN_LENGTH,
PASSPHRASE_ON_DEVICE,
ProtocolVersion,
TrezorClient,
)
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
from .client import ProtocolVersion, TrezorClient
from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError
from .log import DUMP_BYTES
from .messages import Capability, DebugWaitType
from .protobuf import MessageType
from .messages import DebugWaitType
from .tools import parse_path
from .transport import Timeout
from .transport.session import Session, SessionV1, derive_seed
from .transport.session import Session
from .transport.thp.protocol_v1 import ProtocolV1Channel
if t.TYPE_CHECKING:

View File

@ -76,7 +76,7 @@ def detect_protocol_version(transport: "BridgeTransport") -> int:
from .. import mapping, messages
from ..messages import FailureType
protocol_version = ProtocolVersion.PROTOCOL_V1
protocol_version = ProtocolVersion.V1
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
transport.open()
transport.write_chunk(request_type.to_bytes(2, "big") + request_data)
@ -87,13 +87,13 @@ def detect_protocol_version(transport: "BridgeTransport") -> int:
if isinstance(response, messages.Failure):
if response.code == FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol_version = ProtocolVersion.PROTOCOL_V2
protocol_version = ProtocolVersion.V2
return protocol_version
def _is_transport_valid(transport: "BridgeTransport") -> bool:
is_valid = detect_protocol_version(transport) == ProtocolVersion.PROTOCOL_V1
is_valid = detect_protocol_version(transport) == ProtocolVersion.V1
if not is_valid:
LOG.warning("Detected unsupported Bridge transport!")
return is_valid

View File

@ -59,7 +59,7 @@ def test_reset_slip39_basic(
entropy_check_count=0,
_get_entropy=MOCK_GET_ENTROPY,
)
if device_handler.client.protocol_version is ProtocolVersion.PROTOCOL_V2:
if device_handler.client.protocol_version is ProtocolVersion.V2:
reset.confirm_read(debug, middle_r=True)
# confirm new wallet

View File

@ -293,23 +293,17 @@ def _client_unlocked(
args = protocol_marker.args
protocol_version = _raw_client.protocol_version
if (
protocol_version == ProtocolVersion.PROTOCOL_V1
and "protocol_v1" not in args
):
if protocol_version == ProtocolVersion.V1 and "protocol_v1" not in args:
pytest.skip(
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
)
if (
protocol_version == ProtocolVersion.PROTOCOL_V2
and "protocol_v2" not in args
):
if protocol_version == ProtocolVersion.V2 and "protocol_v2" not in args:
pytest.skip(
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
)
if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2:
if _raw_client.protocol_version is ProtocolVersion.V2:
pass
sd_marker = request.node.get_closest_marker("sd_card")
if sd_marker and not _raw_client.features.sd_card_present:

View File

@ -808,7 +808,7 @@ def test_multisession_authorization(client: Client):
)
# Open a second session.
if client.protocol_version is ProtocolVersion.PROTOCOL_V2:
if client.protocol_version is ProtocolVersion.V2:
session_id = b"\x02"
else:
session_id = None

View File

@ -1,10 +1,12 @@
import os
from time import sleep
import pytest
from trezorlib import messages
from trezorlib.client import ProtocolV2Channel
from trezorlib.debuglink import TrezorClientDebugLink as Client
from ...conftest import LOCK_TIME
pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client]

View File

@ -8,7 +8,6 @@ import pytest
from _pytest.nodes import Node
from _pytest.outcomes import Failed
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import TrezorClientDebugLink as Client
from . import common