mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-04-15 06:45:59 +00:00
fix style and errors
This commit is contained in:
parent
68ac7fb0ea
commit
33547c80f4
@ -102,11 +102,9 @@ if utils.USE_THP:
|
||||
return # pylint: disable=lost-exception
|
||||
|
||||
else:
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
async def handle_session(iface: WireInterface) -> None:
|
||||
ctx = CodecContext(iface, WIRE_BUFFER)
|
||||
ctx = CodecContext(iface, WIRE_BUFFER_PROVIDER)
|
||||
next_msg: protocol_common.Message | None = None
|
||||
|
||||
# Take a mark of modules that are imported at this point, so we can
|
||||
|
@ -161,8 +161,8 @@ class Channel:
|
||||
pass # TODO ??
|
||||
if self.fallback_decrypt and self.expected_payload_length == self.bytes_read:
|
||||
self._finish_fallback()
|
||||
from trezor.messages import Failure
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
|
||||
return self.write(
|
||||
Failure(code=FailureType.DeviceIsBusy, message="FALLBACK!"),
|
||||
@ -201,7 +201,7 @@ class Channel:
|
||||
return None
|
||||
return self._handle_init_packet(packet)
|
||||
|
||||
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
||||
def _handle_init_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||
self.fallback_decrypt = False
|
||||
self.fallback_session_id = None
|
||||
self.bytes_read = 0
|
||||
|
@ -135,11 +135,11 @@ class TrezorConnection:
|
||||
|
||||
# Try resume session from id
|
||||
if self.session_id is not None:
|
||||
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
|
||||
if client.protocol_version is ProtocolVersion.V1:
|
||||
session = SessionV1.resume_from_id(
|
||||
client=client, session_id=self.session_id
|
||||
)
|
||||
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||
elif client.protocol_version is ProtocolVersion.V2:
|
||||
session = SessionV2(client, self.session_id)
|
||||
# TODO fix resumption on THP
|
||||
else:
|
||||
|
@ -95,7 +95,7 @@ class TrezorClient:
|
||||
if isinstance(self.protocol, ProtocolV1Channel):
|
||||
self._protocol_version = ProtocolVersion.V1
|
||||
elif isinstance(self.protocol, ProtocolV2Channel):
|
||||
self._protocol_version = ProtocolVersion.PROTOCOL_V2
|
||||
self._protocol_version = ProtocolVersion.V2
|
||||
else:
|
||||
raise Exception("Unknown protocol version")
|
||||
|
||||
@ -129,7 +129,7 @@ class TrezorClient:
|
||||
from .transport.session import SessionV2
|
||||
|
||||
assert isinstance(passphrase, str) or passphrase is None
|
||||
session_id = 1 # TODO fix this with ProtocolV2 session rework
|
||||
session_id = b"\x01" # TODO fix this with ProtocolV2 session rework
|
||||
if session_id is not None:
|
||||
sid = int.from_bytes(session_id, "big")
|
||||
else:
|
||||
@ -197,9 +197,9 @@ class TrezorClient:
|
||||
return protocol
|
||||
|
||||
def reset_protocol(self):
|
||||
if self._protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||
if self._protocol_version == ProtocolVersion.V1:
|
||||
self.protocol = ProtocolV1Channel(self.transport, self.mapping)
|
||||
elif self._protocol_version == ProtocolVersion.PROTOCOL_V2:
|
||||
elif self._protocol_version == ProtocolVersion.V2:
|
||||
self.protocol = ProtocolV2Channel(self.transport, self.mapping)
|
||||
else:
|
||||
assert False
|
||||
@ -217,11 +217,23 @@ class TrezorClient:
|
||||
else:
|
||||
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
self.protocol.write(msg)
|
||||
def _write(self, msg: t.Any, session_id: int | None = None) -> None:
|
||||
if isinstance(self.protocol, ProtocolV1Channel):
|
||||
self.protocol.write(msg)
|
||||
elif isinstance(self.protocol, ProtocolV2Channel):
|
||||
assert session_id is not None
|
||||
self.protocol.write(session_id=session_id, msg=msg)
|
||||
else:
|
||||
raise Exception("Unknown client protocol")
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
return self.protocol.read()
|
||||
def _read(self, session_id: int | None = None) -> t.Any:
|
||||
if isinstance(self.protocol, ProtocolV1Channel):
|
||||
self.protocol.read()
|
||||
elif isinstance(self.protocol, ProtocolV2Channel):
|
||||
assert session_id is not None
|
||||
self.protocol.read(session_id=session_id)
|
||||
else:
|
||||
raise Exception("Unknown client protocol")
|
||||
|
||||
|
||||
def get_default_client(
|
||||
|
@ -32,20 +32,13 @@ from pathlib import Path
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
from . import btc, mapping, messages, models, protobuf
|
||||
from .client import (
|
||||
MAX_PASSPHRASE_LENGTH,
|
||||
MAX_PIN_LENGTH,
|
||||
PASSPHRASE_ON_DEVICE,
|
||||
ProtocolVersion,
|
||||
TrezorClient,
|
||||
)
|
||||
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
|
||||
from .client import ProtocolVersion, TrezorClient
|
||||
from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import Capability, DebugWaitType
|
||||
from .protobuf import MessageType
|
||||
from .messages import DebugWaitType
|
||||
from .tools import parse_path
|
||||
from .transport import Timeout
|
||||
from .transport.session import Session, SessionV1, derive_seed
|
||||
from .transport.session import Session
|
||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
|
@ -76,7 +76,7 @@ def detect_protocol_version(transport: "BridgeTransport") -> int:
|
||||
from .. import mapping, messages
|
||||
from ..messages import FailureType
|
||||
|
||||
protocol_version = ProtocolVersion.PROTOCOL_V1
|
||||
protocol_version = ProtocolVersion.V1
|
||||
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
|
||||
transport.open()
|
||||
transport.write_chunk(request_type.to_bytes(2, "big") + request_data)
|
||||
@ -87,13 +87,13 @@ def detect_protocol_version(transport: "BridgeTransport") -> int:
|
||||
if isinstance(response, messages.Failure):
|
||||
if response.code == FailureType.InvalidProtocol:
|
||||
LOG.debug("Protocol V2 detected")
|
||||
protocol_version = ProtocolVersion.PROTOCOL_V2
|
||||
protocol_version = ProtocolVersion.V2
|
||||
|
||||
return protocol_version
|
||||
|
||||
|
||||
def _is_transport_valid(transport: "BridgeTransport") -> bool:
|
||||
is_valid = detect_protocol_version(transport) == ProtocolVersion.PROTOCOL_V1
|
||||
is_valid = detect_protocol_version(transport) == ProtocolVersion.V1
|
||||
if not is_valid:
|
||||
LOG.warning("Detected unsupported Bridge transport!")
|
||||
return is_valid
|
||||
|
@ -59,7 +59,7 @@ def test_reset_slip39_basic(
|
||||
entropy_check_count=0,
|
||||
_get_entropy=MOCK_GET_ENTROPY,
|
||||
)
|
||||
if device_handler.client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||
if device_handler.client.protocol_version is ProtocolVersion.V2:
|
||||
reset.confirm_read(debug, middle_r=True)
|
||||
|
||||
# confirm new wallet
|
||||
|
@ -293,23 +293,17 @@ def _client_unlocked(
|
||||
args = protocol_marker.args
|
||||
protocol_version = _raw_client.protocol_version
|
||||
|
||||
if (
|
||||
protocol_version == ProtocolVersion.PROTOCOL_V1
|
||||
and "protocol_v1" not in args
|
||||
):
|
||||
if protocol_version == ProtocolVersion.V1 and "protocol_v1" not in args:
|
||||
pytest.skip(
|
||||
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
||||
)
|
||||
|
||||
if (
|
||||
protocol_version == ProtocolVersion.PROTOCOL_V2
|
||||
and "protocol_v2" not in args
|
||||
):
|
||||
if protocol_version == ProtocolVersion.V2 and "protocol_v2" not in args:
|
||||
pytest.skip(
|
||||
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
||||
)
|
||||
|
||||
if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||
if _raw_client.protocol_version is ProtocolVersion.V2:
|
||||
pass
|
||||
sd_marker = request.node.get_closest_marker("sd_card")
|
||||
if sd_marker and not _raw_client.features.sd_card_present:
|
||||
|
@ -808,7 +808,7 @@ def test_multisession_authorization(client: Client):
|
||||
)
|
||||
|
||||
# Open a second session.
|
||||
if client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
||||
if client.protocol_version is ProtocolVersion.V2:
|
||||
session_id = b"\x02"
|
||||
else:
|
||||
session_id = None
|
||||
|
@ -1,10 +1,12 @@
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import messages
|
||||
from trezorlib.client import ProtocolV2Channel
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
|
||||
from ...conftest import LOCK_TIME
|
||||
|
||||
pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client]
|
||||
|
@ -8,7 +8,6 @@ import pytest
|
||||
from _pytest.nodes import Node
|
||||
from _pytest.outcomes import Failed
|
||||
|
||||
from trezorlib.client import ProtocolVersion
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
|
||||
from . import common
|
||||
|
Loading…
Reference in New Issue
Block a user