diff --git a/core/src/main.py b/core/src/main.py index ff9637ebe..846b99977 100644 --- a/core/src/main.py +++ b/core/src/main.py @@ -34,7 +34,7 @@ import trezor.pin # noqa: F401, E402 # === Prepare the USB interfaces first. Do not connect to the host yet. # usb imports trezor.utils and trezor.io which is a C module -import usb # noqa:E402 +import usb # noqa: E402 # create an unimport manager that will be reused in the main loop unimport_manager = utils.unimport() @@ -45,7 +45,7 @@ with unimport_manager: del boot # start the USB -import storage.device # noqa:E402 +import storage.device # noqa: E402 usb.bus.open(storage.device.get_device_id()) diff --git a/core/src/storage/cache_codec.py b/core/src/storage/cache_codec.py index 2be0fd321..086a7786f 100644 --- a/core/src/storage/cache_codec.py +++ b/core/src/storage/cache_codec.py @@ -12,12 +12,12 @@ if TYPE_CHECKING: _MAX_SESSIONS_COUNT = const(10) -_SESSION_ID_LENGTH = const(32) +SESSION_ID_LENGTH = const(32) class SessionCache(DataCache): def __init__(self) -> None: - self.session_id = bytearray(_SESSION_ID_LENGTH) + self.session_id = bytearray(SESSION_ID_LENGTH) if utils.BITCOIN_ONLY: self.fields = ( 64, # APP_COMMON_SEED @@ -44,7 +44,7 @@ class SessionCache(DataCache): # generate a new session id if we don't have it yet if not self.session_id: - self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) + self.session_id[:] = random.bytes(SESSION_ID_LENGTH) # export it as immutable bytes return bytes(self.session_id) @@ -85,7 +85,7 @@ def start_session(received_session_id: bytes | None = None) -> bytes: if ( received_session_id is not None - and len(received_session_id) != _SESSION_ID_LENGTH + and len(received_session_id) != SESSION_ID_LENGTH ): # Prevent the caller from setting received_session_id=b"" and finding a cleared # session. More generally, short-circuit the session id search, because we know diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index b81147a4e..2cf66ece4 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -1,12 +1,12 @@ import builtins -from micropython import const -from typing import TYPE_CHECKING +from micropython import const # pyright: ignore[reportMissingModuleSource] +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from storage.cache_common import DataCache, InvalidSessionError from trezor import utils if TYPE_CHECKING: - from typing import TypeVar + from typing import TypeVar # pyright: ignore[reportShadowedImports] T = TypeVar("T") diff --git a/core/src/trezor/utils.py b/core/src/trezor/utils.py index 7b7f00c18..119a34c8c 100644 --- a/core/src/trezor/utils.py +++ b/core/src/trezor/utils.py @@ -1,6 +1,6 @@ import gc import sys -from trezorutils import ( # noqa: F401 +from trezorutils import ( # noqa: F401 # pyright: ignore[reportMissingImports] BITCOIN_ONLY, EMULATOR, INTERNAL_MODEL, @@ -25,7 +25,7 @@ from trezorutils import ( # noqa: F401 unit_color, unit_packaging, ) -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] # Will get replaced by "True" / "False" in the build process # However, needs to stay as an exported symbol for the unit tests @@ -37,7 +37,7 @@ USE_THP = True # TODO move elsewhere, probably to core/embed/trezorhal/... if __debug__: if EMULATOR: - import uos + import uos # pyright: ignore[reportMissingModuleSource] DISABLE_ANIMATION = int(uos.getenv("TREZOR_DISABLE_ANIMATION") or "0") LOG_MEMORY = int(uos.getenv("TREZOR_LOG_MEMORY") or "0") @@ -45,7 +45,13 @@ if __debug__: LOG_MEMORY = 0 if TYPE_CHECKING: - from typing import Any, Iterator, Protocol, Sequence, TypeVar + from typing import ( # pyright: ignore[reportShadowedImports] + Any, + Iterator, + Protocol, + Sequence, + TypeVar, + ) from trezor.protobuf import MessageType @@ -113,13 +119,13 @@ def presize_module(modname: str, size: int) -> None: if __debug__: def mem_dump(filename: str) -> None: - from micropython import mem_info + from micropython import mem_info # pyright: ignore[reportMissingModuleSource] print(f"### sysmodules ({len(sys.modules)}):") for mod in sys.modules: print("*", mod) if EMULATOR: - from trezorutils import meminfo + from trezorutils import meminfo # pyright: ignore[reportMissingImports] print("### dumping to", filename) meminfo(filename) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index c6cea0b31..55100e881 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -30,7 +30,7 @@ from storage.cache_common import InvalidSessionError from trezor import log, loop, protobuf, utils from trezor.enums import FailureType from trezor.messages import Failure -from trezor.wire import codec_v1, context, message_handler, protocol_common, thp_v1 +from trezor.wire import context, message_handler, protocol_common, thp_v1 from trezor.wire.errors import DataError, Error # Import all errors into namespace, so that `wire.Error` is available from @@ -63,7 +63,7 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None: if utils.USE_THP: loop.schedule(handle_thp_session(iface, is_debug_session)) else: - loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session)) + loop.schedule(handle_session(iface, is_debug_session)) def wrap_protobuf_load( @@ -128,15 +128,12 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals print("Exception raised:", exc) -async def handle_session( - iface: WireInterface, codec_session_id: int, is_debug_session: bool = False -) -> None: +async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None: if __debug__ and is_debug_session: ctx_buffer = WIRE_BUFFER_DEBUG else: ctx_buffer = WIRE_BUFFER - session_id = codec_session_id.to_bytes(4, "big") - ctx = context.CodecContext(iface, ctx_buffer, session_id) + ctx = context.CodecContext(iface, ctx_buffer) next_msg: protocol_common.MessageWithId | None = None if __debug__ and is_debug_session: @@ -165,10 +162,6 @@ async def handle_session( msg = next_msg next_msg = None - # Set ctx.session_id to the value msg.session_id - if msg.session_id is not None: - ctx.channel_id = msg.session_id - try: next_msg_without_id = await message_handler.handle_single_message( ctx, msg, use_workflow=not is_debug_session diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index d20c85edd..5eca51898 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -65,12 +65,10 @@ class CodecContext(Context): self, iface: WireInterface, buffer: bytearray, - channel_id: bytes, ) -> None: self.iface = iface self.buffer = buffer - self.channel_id = channel_id - super().__init__(iface, channel_id) + super().__init__(iface, codec_v1.SESSION_ID.to_bytes(2, "big")) def read_from_wire(self) -> Awaitable[MessageWithId]: """Read a whole message from the wire without parsing it.""" @@ -100,15 +98,10 @@ class CodecContext(Context): to save on having to decode the type code into a protobuf class. """ if __debug__: - if self.channel_id is not None: - sid = int.from_bytes(self.channel_id, "big") - else: - sid = -1 log.debug( __name__, - "%s:%x expect: %s", + "%s: expect: %s", self.iface.iface_num(), - sid, expected_type.MESSAGE_NAME if expected_type else expected_types, ) @@ -120,22 +113,14 @@ class CodecContext(Context): if msg.type not in expected_types: raise UnexpectedMessageWithId(msg) - # TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError - # (and maybe update ctx.session_id - depends on expected behaviour) - if expected_type is None: expected_type = protobuf.type_for_wire(msg.type) if __debug__: - if self.channel_id is not None: - sid = int.from_bytes(self.channel_id, "big") - else: - sid = -1 log.debug( __name__, - "%s:%x read: %s", + "%s: read: %s", self.iface.iface_num(), - sid, expected_type.MESSAGE_NAME, ) @@ -147,15 +132,10 @@ class CodecContext(Context): async def write(self, msg: protobuf.MessageType) -> None: """Write a message to the wire.""" if __debug__: - if self.channel_id is not None: - sid = int.from_bytes(self.channel_id, "big") - else: - sid = -1 log.debug( __name__, - "%s:%x write: %s", + "%s: write: %s", self.iface.iface_num(), - sid, msg.MESSAGE_NAME, ) diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index fdea9cbae..3754cf726 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -1,4 +1,4 @@ -from common import * # isort:skip +from common import * # isort:skip # noqa: F403 from mock_storage import mock_storage @@ -23,7 +23,9 @@ def is_session_started() -> bool: return _PROTOCOL_CACHE.get_active_session() is not None -class TestStorageCache(unittest.TestCase): +class TestStorageCache( + unittest.TestCase +): # noqa: F405 # pyright: ignore[reportUndefinedVariable] def setUp(self): cache.clear_all() diff --git a/core/tests/test_trezor.crypto.base32.py b/core/tests/test_trezor.crypto.base32.py index 913a386b4..58bbb5d4d 100644 --- a/core/tests/test_trezor.crypto.base32.py +++ b/core/tests/test_trezor.crypto.base32.py @@ -26,12 +26,12 @@ class TestCryptoBase32(unittest.TestCase): "PJWHK5DPOVRWW6JANN2W4IDVOBSWYIDEMFRGK3DTNNSSA33EPE======", ), # fmt: off - (b"中文", "4S4K3ZUWQ4======"), # noqa:E999 - (b"中文1", "4S4K3ZUWQ4YQ===="), # noqa:E999 - (b"中文12", "4S4K3ZUWQ4YTE==="), # noqa:E999 - (b"aécio", "MHB2SY3JN4======"), # noqa:E999 - (b"𠜎", "6CQJZDQ="), # noqa:E999 - (b"Base64是一種基於64個可列印字元來表示二進制資料的表示方法", # noqa:E999 + (b"中文", "4S4K3ZUWQ4======"), # noqa: E999 + (b"中文1", "4S4K3ZUWQ4YQ===="), # noqa: E999 + (b"中文12", "4S4K3ZUWQ4YTE==="), # noqa: E999 + (b"aécio", "MHB2SY3JN4======"), # noqa: E999 + (b"𠜎", "6CQJZDQ="), # noqa: E999 + (b"Base64是一種基於64個可列印字元來表示二進制資料的表示方法", # noqa: E999 "IJQXGZJWGTTJRL7EXCAOPKFO4WP3VZUWXQ3DJZMARPSY7L7FRCL6LDNQ4WWZPZMFQPSL5BXIUGUOPJF24S5IZ2MAWLSYRNXIWOD6NFUZ46NIJ2FBVDT2JOXGS246NM4V"), # fmt: on ] diff --git a/core/tests/test_trezor.wire.thp_v1.py b/core/tests/test_trezor.wire.thp_v1.py index 537955a85..e43c128ba 100644 --- a/core/tests/test_trezor.wire.thp_v1.py +++ b/core/tests/test_trezor.wire.thp_v1.py @@ -79,7 +79,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): def setUp(self): self.interface = MockHID(0xDEADBEEF) if not utils.USE_THP: - import storage.cache_thp # noQA:F401 + import storage.cache_thp # noqa: F401 def test_simple(self): cid_req_header = make_header(