mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-06-16 04:58:45 +00:00
fix style and errors
This commit is contained in:
parent
317a8cb3cf
commit
039b74ce56
@ -102,11 +102,9 @@ if utils.USE_THP:
|
|||||||
return # pylint: disable=lost-exception
|
return # pylint: disable=lost-exception
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
|
||||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
|
||||||
|
|
||||||
async def handle_session(iface: WireInterface) -> None:
|
async def handle_session(iface: WireInterface) -> None:
|
||||||
ctx = CodecContext(iface, WIRE_BUFFER)
|
ctx = CodecContext(iface, WIRE_BUFFER_PROVIDER)
|
||||||
next_msg: protocol_common.Message | None = None
|
next_msg: protocol_common.Message | None = None
|
||||||
|
|
||||||
# Take a mark of modules that are imported at this point, so we can
|
# Take a mark of modules that are imported at this point, so we can
|
||||||
|
@ -161,8 +161,8 @@ class Channel:
|
|||||||
pass # TODO ??
|
pass # TODO ??
|
||||||
if self.fallback_decrypt and self.expected_payload_length == self.bytes_read:
|
if self.fallback_decrypt and self.expected_payload_length == self.bytes_read:
|
||||||
self._finish_fallback()
|
self._finish_fallback()
|
||||||
from trezor.messages import Failure
|
|
||||||
from trezor.enums import FailureType
|
from trezor.enums import FailureType
|
||||||
|
from trezor.messages import Failure
|
||||||
|
|
||||||
return self.write(
|
return self.write(
|
||||||
Failure(code=FailureType.DeviceIsBusy, message="FALLBACK!"),
|
Failure(code=FailureType.DeviceIsBusy, message="FALLBACK!"),
|
||||||
@ -201,7 +201,7 @@ class Channel:
|
|||||||
return None
|
return None
|
||||||
return self._handle_init_packet(packet)
|
return self._handle_init_packet(packet)
|
||||||
|
|
||||||
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
def _handle_init_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||||
self.fallback_decrypt = False
|
self.fallback_decrypt = False
|
||||||
self.fallback_session_id = None
|
self.fallback_session_id = None
|
||||||
self.bytes_read = 0
|
self.bytes_read = 0
|
||||||
|
@ -135,11 +135,11 @@ class TrezorConnection:
|
|||||||
|
|
||||||
# Try resume session from id
|
# Try resume session from id
|
||||||
if self.session_id is not None:
|
if self.session_id is not None:
|
||||||
if client.protocol_version is ProtocolVersion.PROTOCOL_V1:
|
if client.protocol_version is ProtocolVersion.V1:
|
||||||
session = SessionV1.resume_from_id(
|
session = SessionV1.resume_from_id(
|
||||||
client=client, session_id=self.session_id
|
client=client, session_id=self.session_id
|
||||||
)
|
)
|
||||||
elif client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
elif client.protocol_version is ProtocolVersion.V2:
|
||||||
session = SessionV2(client, self.session_id)
|
session = SessionV2(client, self.session_id)
|
||||||
# TODO fix resumption on THP
|
# TODO fix resumption on THP
|
||||||
else:
|
else:
|
||||||
|
@ -95,7 +95,7 @@ class TrezorClient:
|
|||||||
if isinstance(self.protocol, ProtocolV1Channel):
|
if isinstance(self.protocol, ProtocolV1Channel):
|
||||||
self._protocol_version = ProtocolVersion.V1
|
self._protocol_version = ProtocolVersion.V1
|
||||||
elif isinstance(self.protocol, ProtocolV2Channel):
|
elif isinstance(self.protocol, ProtocolV2Channel):
|
||||||
self._protocol_version = ProtocolVersion.PROTOCOL_V2
|
self._protocol_version = ProtocolVersion.V2
|
||||||
else:
|
else:
|
||||||
raise Exception("Unknown protocol version")
|
raise Exception("Unknown protocol version")
|
||||||
|
|
||||||
@ -129,7 +129,7 @@ class TrezorClient:
|
|||||||
from .transport.session import SessionV2
|
from .transport.session import SessionV2
|
||||||
|
|
||||||
assert isinstance(passphrase, str) or passphrase is None
|
assert isinstance(passphrase, str) or passphrase is None
|
||||||
session_id = 1 # TODO fix this with ProtocolV2 session rework
|
session_id = b"\x01" # TODO fix this with ProtocolV2 session rework
|
||||||
if session_id is not None:
|
if session_id is not None:
|
||||||
sid = int.from_bytes(session_id, "big")
|
sid = int.from_bytes(session_id, "big")
|
||||||
else:
|
else:
|
||||||
@ -197,9 +197,9 @@ class TrezorClient:
|
|||||||
return protocol
|
return protocol
|
||||||
|
|
||||||
def reset_protocol(self):
|
def reset_protocol(self):
|
||||||
if self._protocol_version == ProtocolVersion.PROTOCOL_V1:
|
if self._protocol_version == ProtocolVersion.V1:
|
||||||
self.protocol = ProtocolV1Channel(self.transport, self.mapping)
|
self.protocol = ProtocolV1Channel(self.transport, self.mapping)
|
||||||
elif self._protocol_version == ProtocolVersion.PROTOCOL_V2:
|
elif self._protocol_version == ProtocolVersion.V2:
|
||||||
self.protocol = ProtocolV2Channel(self.transport, self.mapping)
|
self.protocol = ProtocolV2Channel(self.transport, self.mapping)
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
@ -217,11 +217,23 @@ class TrezorClient:
|
|||||||
else:
|
else:
|
||||||
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
|
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
|
||||||
|
|
||||||
def _write(self, msg: t.Any) -> None:
|
def _write(self, msg: t.Any, session_id: int | None = None) -> None:
|
||||||
self.protocol.write(msg)
|
if isinstance(self.protocol, ProtocolV1Channel):
|
||||||
|
self.protocol.write(msg)
|
||||||
|
elif isinstance(self.protocol, ProtocolV2Channel):
|
||||||
|
assert session_id is not None
|
||||||
|
self.protocol.write(session_id=session_id, msg=msg)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown client protocol")
|
||||||
|
|
||||||
def _read(self) -> t.Any:
|
def _read(self, session_id: int | None = None) -> t.Any:
|
||||||
return self.protocol.read()
|
if isinstance(self.protocol, ProtocolV1Channel):
|
||||||
|
self.protocol.read()
|
||||||
|
elif isinstance(self.protocol, ProtocolV2Channel):
|
||||||
|
assert session_id is not None
|
||||||
|
self.protocol.read(session_id=session_id)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown client protocol")
|
||||||
|
|
||||||
|
|
||||||
def get_default_client(
|
def get_default_client(
|
||||||
|
@ -32,20 +32,13 @@ from pathlib import Path
|
|||||||
from mnemonic import Mnemonic
|
from mnemonic import Mnemonic
|
||||||
|
|
||||||
from . import btc, mapping, messages, models, protobuf
|
from . import btc, mapping, messages, models, protobuf
|
||||||
from .client import (
|
from .client import ProtocolVersion, TrezorClient
|
||||||
MAX_PASSPHRASE_LENGTH,
|
from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError
|
||||||
MAX_PIN_LENGTH,
|
|
||||||
PASSPHRASE_ON_DEVICE,
|
|
||||||
ProtocolVersion,
|
|
||||||
TrezorClient,
|
|
||||||
)
|
|
||||||
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
|
|
||||||
from .log import DUMP_BYTES
|
from .log import DUMP_BYTES
|
||||||
from .messages import Capability, DebugWaitType
|
from .messages import DebugWaitType
|
||||||
from .protobuf import MessageType
|
|
||||||
from .tools import parse_path
|
from .tools import parse_path
|
||||||
from .transport import Timeout
|
from .transport import Timeout
|
||||||
from .transport.session import Session, SessionV1, derive_seed
|
from .transport.session import Session
|
||||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
|
@ -76,7 +76,7 @@ def detect_protocol_version(transport: "BridgeTransport") -> int:
|
|||||||
from .. import mapping, messages
|
from .. import mapping, messages
|
||||||
from ..messages import FailureType
|
from ..messages import FailureType
|
||||||
|
|
||||||
protocol_version = ProtocolVersion.PROTOCOL_V1
|
protocol_version = ProtocolVersion.V1
|
||||||
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
|
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
|
||||||
transport.open()
|
transport.open()
|
||||||
transport.write_chunk(request_type.to_bytes(2, "big") + request_data)
|
transport.write_chunk(request_type.to_bytes(2, "big") + request_data)
|
||||||
@ -87,13 +87,13 @@ def detect_protocol_version(transport: "BridgeTransport") -> int:
|
|||||||
if isinstance(response, messages.Failure):
|
if isinstance(response, messages.Failure):
|
||||||
if response.code == FailureType.InvalidProtocol:
|
if response.code == FailureType.InvalidProtocol:
|
||||||
LOG.debug("Protocol V2 detected")
|
LOG.debug("Protocol V2 detected")
|
||||||
protocol_version = ProtocolVersion.PROTOCOL_V2
|
protocol_version = ProtocolVersion.V2
|
||||||
|
|
||||||
return protocol_version
|
return protocol_version
|
||||||
|
|
||||||
|
|
||||||
def _is_transport_valid(transport: "BridgeTransport") -> bool:
|
def _is_transport_valid(transport: "BridgeTransport") -> bool:
|
||||||
is_valid = detect_protocol_version(transport) == ProtocolVersion.PROTOCOL_V1
|
is_valid = detect_protocol_version(transport) == ProtocolVersion.V1
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
LOG.warning("Detected unsupported Bridge transport!")
|
LOG.warning("Detected unsupported Bridge transport!")
|
||||||
return is_valid
|
return is_valid
|
||||||
|
@ -59,7 +59,7 @@ def test_reset_slip39_basic(
|
|||||||
entropy_check_count=0,
|
entropy_check_count=0,
|
||||||
_get_entropy=MOCK_GET_ENTROPY,
|
_get_entropy=MOCK_GET_ENTROPY,
|
||||||
)
|
)
|
||||||
if device_handler.client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
if device_handler.client.protocol_version is ProtocolVersion.V2:
|
||||||
reset.confirm_read(debug, middle_r=True)
|
reset.confirm_read(debug, middle_r=True)
|
||||||
|
|
||||||
# confirm new wallet
|
# confirm new wallet
|
||||||
|
@ -293,23 +293,17 @@ def _client_unlocked(
|
|||||||
args = protocol_marker.args
|
args = protocol_marker.args
|
||||||
protocol_version = _raw_client.protocol_version
|
protocol_version = _raw_client.protocol_version
|
||||||
|
|
||||||
if (
|
if protocol_version == ProtocolVersion.V1 and "protocol_v1" not in args:
|
||||||
protocol_version == ProtocolVersion.PROTOCOL_V1
|
|
||||||
and "protocol_v1" not in args
|
|
||||||
):
|
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if protocol_version == ProtocolVersion.V2 and "protocol_v2" not in args:
|
||||||
protocol_version == ProtocolVersion.PROTOCOL_V2
|
|
||||||
and "protocol_v2" not in args
|
|
||||||
):
|
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
f"Skipping test for device/emulator with protocol_v{protocol_version} - the protocol is not supported."
|
||||||
)
|
)
|
||||||
|
|
||||||
if _raw_client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
if _raw_client.protocol_version is ProtocolVersion.V2:
|
||||||
pass
|
pass
|
||||||
sd_marker = request.node.get_closest_marker("sd_card")
|
sd_marker = request.node.get_closest_marker("sd_card")
|
||||||
if sd_marker and not _raw_client.features.sd_card_present:
|
if sd_marker and not _raw_client.features.sd_card_present:
|
||||||
|
@ -808,7 +808,7 @@ def test_multisession_authorization(client: Client):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Open a second session.
|
# Open a second session.
|
||||||
if client.protocol_version is ProtocolVersion.PROTOCOL_V2:
|
if client.protocol_version is ProtocolVersion.V2:
|
||||||
session_id = b"\x02"
|
session_id = b"\x02"
|
||||||
else:
|
else:
|
||||||
session_id = None
|
session_id = None
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from trezorlib import messages
|
from trezorlib import messages
|
||||||
from trezorlib.client import ProtocolV2Channel
|
from trezorlib.client import ProtocolV2Channel
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
|
|
||||||
from ...conftest import LOCK_TIME
|
from ...conftest import LOCK_TIME
|
||||||
|
|
||||||
pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client]
|
pytestmark = [pytest.mark.protocol("protocol_v2"), pytest.mark.invalidate_client]
|
||||||
|
@ -8,7 +8,6 @@ import pytest
|
|||||||
from _pytest.nodes import Node
|
from _pytest.nodes import Node
|
||||||
from _pytest.outcomes import Failed
|
from _pytest.outcomes import Failed
|
||||||
|
|
||||||
from trezorlib.client import ProtocolVersion
|
|
||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
|
|
||||||
from . import common
|
from . import common
|
||||||
|
Loading…
Reference in New Issue
Block a user