1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-16 04:58:45 +00:00

fix style and errors

This commit is contained in:
M1nd3r 2025-04-08 16:16:48 +02:00
parent 317a8cb3cf
commit 039b74ce56
11 changed files with 39 additions and 41 deletions

View File

@ -102,11 +102,9 @@ if utils.USE_THP:
return # pylint: disable=lost-exception return # pylint: disable=lost-exception
else: else:
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
async def handle_session(iface: WireInterface) -> None: async def handle_session(iface: WireInterface) -> None:
ctx = CodecContext(iface, WIRE_BUFFER) ctx = CodecContext(iface, WIRE_BUFFER_PROVIDER)
next_msg: protocol_common.Message | None = None next_msg: protocol_common.Message | None = None
# Take a mark of modules that are imported at this point, so we can # Take a mark of modules that are imported at this point, so we can

View File

@ -161,8 +161,8 @@ class Channel:
pass # TODO ?? pass # TODO ??
if self.fallback_decrypt and self.expected_payload_length == self.bytes_read: if self.fallback_decrypt and self.expected_payload_length == self.bytes_read:
self._finish_fallback() self._finish_fallback()
from trezor.messages import Failure
from trezor.enums import FailureType from trezor.enums import FailureType
from trezor.messages import Failure
return self.write( return self.write(
Failure(code=FailureType.DeviceIsBusy, message="FALLBACK!"), Failure(code=FailureType.DeviceIsBusy, message="FALLBACK!"),
@ -201,7 +201,7 @@ class Channel:
return None return None
return self._handle_init_packet(packet) return self._handle_init_packet(packet)
def _handle_init_packet(self, packet: utils.BufferType) -> None: def _handle_init_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
self.fallback_decrypt = False self.fallback_decrypt = False
self.fallback_session_id = None self.fallback_session_id = None
self.bytes_read = 0 self.bytes_read = 0

View File

@ -135,11 +135,11 @@ class TrezorConnection:
# Try resume session from id # Try resume session from id
if self.session_id is not None: if self.session_id is not None:
if client.protocol_version is ProtocolVersion.PROTOCOL_V1: if client.protocol_version is ProtocolVersion.V1:
session = SessionV1.resume_from_id( session = SessionV1.resume_from_id(
client=client, session_id=self.session_id client=client, session_id=self.session_id
) )
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2: elif client.protocol_version is ProtocolVersion.V2:
session = SessionV2(client, self.session_id) session = SessionV2(client, self.session_id)
# TODO fix resumption on THP # TODO fix resumption on THP
else: else:

View File

@ -95,7 +95,7 @@ class TrezorClient:
if isinstance(self.protocol, ProtocolV1Channel): if isinstance(self.protocol, ProtocolV1Channel):
self._protocol_version = ProtocolVersion.V1 self._protocol_version = ProtocolVersion.V1
elif isinstance(self.protocol, ProtocolV2Channel): elif isinstance(self.protocol, ProtocolV2Channel):
self._protocol_version = ProtocolVersion.PROTOCOL_V2 self._protocol_version = ProtocolVersion.V2
else: else:
raise Exception("Unknown protocol version") raise Exception("Unknown protocol version")
@ -129,7 +129,7 @@ class TrezorClient:
from .transport.session import SessionV2 from .transport.session import SessionV2
assert isinstance(passphrase, str) or passphrase is None assert isinstance(passphrase, str) or passphrase is None
session_id = 1 # TODO fix this with ProtocolV2 session rework session_id = b"\x01" # TODO fix this with ProtocolV2 session rework
if session_id is not None: if session_id is not None:
sid = int.from_bytes(session_id, "big") sid = int.from_bytes(session_id, "big")
else: else:
@ -197,9 +197,9 @@ class TrezorClient:
return protocol return protocol
def reset_protocol(self): def reset_protocol(self):
if self._protocol_version == ProtocolVersion.PROTOCOL_V1: if self._protocol_version == ProtocolVersion.V1:
self.protocol = ProtocolV1Channel(self.transport, self.mapping) self.protocol = ProtocolV1Channel(self.transport, self.mapping)
elif self._protocol_version == ProtocolVersion.PROTOCOL_V2: elif self._protocol_version == ProtocolVersion.V2:
self.protocol = ProtocolV2Channel(self.transport, self.mapping) self.protocol = ProtocolV2Channel(self.transport, self.mapping)
else: else:
assert False assert False
@ -217,11 +217,23 @@ class TrezorClient:
else: else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
def _write(self, msg: t.Any) -> None: def _write(self, msg: t.Any, session_id: int | None = None) -> None:
self.protocol.write(msg) if isinstance(self.protocol, ProtocolV1Channel):
self.protocol.write(msg)
elif isinstance(self.protocol, ProtocolV2Channel):
assert session_id is not None
self.protocol.write(session_id=session_id, msg=msg)
else:
raise Exception("Unknown client protocol")
def _read(self) -> t.Any: def _read(self, session_id: int | None = None) -> t.Any:
return self.protocol.read() if isinstance(self.protocol, ProtocolV1Channel):
self.protocol.read()
elif isinstance(self.protocol, ProtocolV2Channel):
assert session_id is not None
self.protocol.read(session_id=session_id)
else:
raise Exception("Unknown client protocol")
def get_default_client( def get_default_client(

View File

@ -32,20 +32,13 @@ from pathlib import Path
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import btc, mapping, messages, models, protobuf from . import btc, mapping, messages, models, protobuf
from .client import ( from .client import ProtocolVersion, TrezorClient
MAX_PASSPHRASE_LENGTH, from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError
MAX_PIN_LENGTH,
PASSPHRASE_ON_DEVICE,
ProtocolVersion,
TrezorClient,
)
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import Capability, DebugWaitType from .messages import DebugWaitType
from .protobuf import MessageType
from .tools import parse_path from .tools import parse_path
from .transport import Timeout from .transport import Timeout
from .transport.session import Session, SessionV1, derive_seed from .transport.session import Session
from .transport.thp.protocol_v1 import ProtocolV1Channel from .transport.thp.protocol_v1 import ProtocolV1Channel
if t.TYPE_CHECKING: if t.TYPE_CHECKING:

View File

@ -76,7 +76,7 @@ def detect_protocol_version(transport: "BridgeTransport") -> int:
from .. import mapping, messages from .. import mapping, messages
from ..messages import FailureType from ..messages import FailureType
protocol_version = ProtocolVersion.PROTOCOL_V1 protocol_version = ProtocolVersion.V1
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize()) request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
transport.open() transport.open()
transport.write_chunk(request_type.to_bytes(2, "big") + request_data) transport.write_chunk(request_type.to_bytes(2, "big") + request_data)
@ -87,13 +87,13 @@ def detect_protocol_version(transport: "BridgeTransport") -> int:
if isinstance(response, messages.Failure): if isinstance(response, messages.Failure):
if response.code == FailureType.InvalidProtocol: if response.code == FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected") LOG.debug("Protocol V2 detected")
protocol_version = ProtocolVersion.PROTOCOL_V2 protocol_version = ProtocolVersion.V2
return protocol_version return protocol_version
def _is_transport_valid(transport: "BridgeTransport") -> bool: def _is_transport_valid(transport: "BridgeTransport") -> bool:
is_valid = detect_protocol_version(transport) == ProtocolVersion.PROTOCOL_V1 is_valid = detect_protocol_version(transport) == ProtocolVersion.V1
if not is_valid: if not is_valid:
LOG.warning("Detected unsupported Bridge transport!") LOG.warning("Detected unsupported Bridge transport!")
return is_valid return is_valid

View File

@ -59,7 +59,7 @@ def test_reset_slip39_basic(
entropy_check_count=0, entropy_check_count=0,
_get_entropy=MOCK_GET_ENTROPY, _get_entropy=MOCK_GET_ENTROPY,
) )
if device_handler.client.protocol_version is ProtocolVersion.PROTOCOL_V2: if device_handler.client.protocol_version is ProtocolVersion.V2:
reset.confirm_read(debug, middle_r=True) reset.confirm_read(debug, middle_r=True)
# confirm new wallet # confirm new wallet

View File

@ -293,23 +293,17 @@ def _client_unlocked(
args = protocol_marker.args args = protocol_marker.args
protocol_version = _raw_client.protocol_version protocol_version = _raw_client.protocol_version
if ( if protocol_version == ProtocolVersion.V1 and "protocol_v1" not in args:
protocol_version == ProtocolVersion.PROTOCOL_V1
and "protocol_v1" not in args
):
pytest.skip( pytest.skip(
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
) )
if ( if protocol_version == ProtocolVersion.V2 and "protocol_v2" not in args:
protocol_version == ProtocolVersion.PROTOCOL_V2
and "protocol_v2" not in args
):
pytest.skip( pytest.skip(
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported." f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
) )
if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2: if _raw_client.protocol_version is ProtocolVersion.V2:
pass pass
sd_marker = request.node.get_closest_marker("sd_card") sd_marker = request.node.get_closest_marker("sd_card")
if sd_marker and not _raw_client.features.sd_card_present: if sd_marker and not _raw_client.features.sd_card_present:

View File

@ -808,7 +808,7 @@ def test_multisession_authorization(client: Client):
) )
# Open a second session. # Open a second session.
if client.protocol_version is ProtocolVersion.PROTOCOL_V2: if client.protocol_version is ProtocolVersion.V2:
session_id = b"\x02" session_id = b"\x02"
else: else:
session_id = None session_id = None

View File

@ -1,10 +1,12 @@
import os import os
from time import sleep from time import sleep
import pytest import pytest
from trezorlib import messages from trezorlib import messages
from trezorlib.client import ProtocolV2Channel from trezorlib.client import ProtocolV2Channel
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from ...conftest import LOCK_TIME from ...conftest import LOCK_TIME
pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client] pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client]

View File

@ -8,7 +8,6 @@ import pytest
from _pytest.nodes import Node from _pytest.nodes import Node
from _pytest.outcomes import Failed from _pytest.outcomes import Failed
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from . import common from . import common