Remove unnecessary session_ids in old Codec, fix storage test, suppress invalid warnings

M1nd3r/thp5
M1nd3r 1 month ago
parent e3c826c691
commit fe1ab40302

@ -34,7 +34,7 @@ import trezor.pin # noqa: F401, E402
# === Prepare the USB interfaces first. Do not connect to the host yet.
# usb imports trezor.utils and trezor.io which is a C module
import usb # noqa:E402
import usb # noqa: E402
# create an unimport manager that will be reused in the main loop
unimport_manager = utils.unimport()
@ -45,7 +45,7 @@ with unimport_manager:
del boot
# start the USB
import storage.device # noqa:E402
import storage.device # noqa: E402
usb.bus.open(storage.device.get_device_id())

@ -12,12 +12,12 @@ if TYPE_CHECKING:
_MAX_SESSIONS_COUNT = const(10)
_SESSION_ID_LENGTH = const(32)
SESSION_ID_LENGTH = const(32)
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
self.session_id = bytearray(SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
64, # APP_COMMON_SEED
@ -44,7 +44,7 @@ class SessionCache(DataCache):
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
@ -85,7 +85,7 @@ def start_session(received_session_id: bytes | None = None) -> bytes:
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
and len(received_session_id) != SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know

@ -1,12 +1,12 @@
import builtins
from micropython import const
from typing import TYPE_CHECKING
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage.cache_common import DataCache, InvalidSessionError
from trezor import utils
if TYPE_CHECKING:
from typing import TypeVar
from typing import TypeVar # pyright: ignore[reportShadowedImports]
T = TypeVar("T")

@ -1,6 +1,6 @@
import gc
import sys
from trezorutils import ( # noqa: F401
from trezorutils import ( # noqa: F401 # pyright: ignore[reportMissingImports]
BITCOIN_ONLY,
EMULATOR,
INTERNAL_MODEL,
@ -25,7 +25,7 @@ from trezorutils import ( # noqa: F401
unit_color,
unit_packaging,
)
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
# Will get replaced by "True" / "False" in the build process
# However, needs to stay as an exported symbol for the unit tests
@ -37,7 +37,7 @@ USE_THP = True # TODO move elsewhere, probably to core/embed/trezorhal/...
if __debug__:
if EMULATOR:
import uos
import uos # pyright: ignore[reportMissingModuleSource]
DISABLE_ANIMATION = int(uos.getenv("TREZOR_DISABLE_ANIMATION") or "0")
LOG_MEMORY = int(uos.getenv("TREZOR_LOG_MEMORY") or "0")
@ -45,7 +45,13 @@ if __debug__:
LOG_MEMORY = 0
if TYPE_CHECKING:
from typing import Any, Iterator, Protocol, Sequence, TypeVar
from typing import ( # pyright: ignore[reportShadowedImports]
Any,
Iterator,
Protocol,
Sequence,
TypeVar,
)
from trezor.protobuf import MessageType
@ -113,13 +119,13 @@ def presize_module(modname: str, size: int) -> None:
if __debug__:
def mem_dump(filename: str) -> None:
from micropython import mem_info
from micropython import mem_info # pyright: ignore[reportMissingModuleSource]
print(f"### sysmodules ({len(sys.modules)}):")
for mod in sys.modules:
print("*", mod)
if EMULATOR:
from trezorutils import meminfo
from trezorutils import meminfo # pyright: ignore[reportMissingImports]
print("### dumping to", filename)
meminfo(filename)

@ -30,7 +30,7 @@ from storage.cache_common import InvalidSessionError
from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire import codec_v1, context, message_handler, protocol_common, thp_v1
from trezor.wire import context, message_handler, protocol_common, thp_v1
from trezor.wire.errors import DataError, Error
# Import all errors into namespace, so that `wire.Error` is available from
@ -63,7 +63,7 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
if utils.USE_THP:
loop.schedule(handle_thp_session(iface, is_debug_session))
else:
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
loop.schedule(handle_session(iface, is_debug_session))
def wrap_protobuf_load(
@ -128,15 +128,12 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
print("Exception raised:", exc)
async def handle_session(
iface: WireInterface, codec_session_id: int, is_debug_session: bool = False
) -> None:
async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None:
if __debug__ and is_debug_session:
ctx_buffer = WIRE_BUFFER_DEBUG
else:
ctx_buffer = WIRE_BUFFER
session_id = codec_session_id.to_bytes(4, "big")
ctx = context.CodecContext(iface, ctx_buffer, session_id)
ctx = context.CodecContext(iface, ctx_buffer)
next_msg: protocol_common.MessageWithId | None = None
if __debug__ and is_debug_session:
@ -165,10 +162,6 @@ async def handle_session(
msg = next_msg
next_msg = None
# Set ctx.session_id to the value msg.session_id
if msg.session_id is not None:
ctx.channel_id = msg.session_id
try:
next_msg_without_id = await message_handler.handle_single_message(
ctx, msg, use_workflow=not is_debug_session

@ -65,12 +65,10 @@ class CodecContext(Context):
self,
iface: WireInterface,
buffer: bytearray,
channel_id: bytes,
) -> None:
self.iface = iface
self.buffer = buffer
self.channel_id = channel_id
super().__init__(iface, channel_id)
super().__init__(iface, codec_v1.SESSION_ID.to_bytes(2, "big"))
def read_from_wire(self) -> Awaitable[MessageWithId]:
"""Read a whole message from the wire without parsing it."""
@ -100,15 +98,10 @@ class CodecContext(Context):
to save on having to decode the type code into a protobuf class.
"""
if __debug__:
if self.channel_id is not None:
sid = int.from_bytes(self.channel_id, "big")
else:
sid = -1
log.debug(
__name__,
"%s:%x expect: %s",
"%s: expect: %s",
self.iface.iface_num(),
sid,
expected_type.MESSAGE_NAME if expected_type else expected_types,
)
@ -120,22 +113,14 @@ class CodecContext(Context):
if msg.type not in expected_types:
raise UnexpectedMessageWithId(msg)
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
# (and maybe update ctx.session_id - depends on expected behaviour)
if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type)
if __debug__:
if self.channel_id is not None:
sid = int.from_bytes(self.channel_id, "big")
else:
sid = -1
log.debug(
__name__,
"%s:%x read: %s",
"%s: read: %s",
self.iface.iface_num(),
sid,
expected_type.MESSAGE_NAME,
)
@ -147,15 +132,10 @@ class CodecContext(Context):
async def write(self, msg: protobuf.MessageType) -> None:
"""Write a message to the wire."""
if __debug__:
if self.channel_id is not None:
sid = int.from_bytes(self.channel_id, "big")
else:
sid = -1
log.debug(
__name__,
"%s:%x write: %s",
"%s: write: %s",
self.iface.iface_num(),
sid,
msg.MESSAGE_NAME,
)

@ -1,4 +1,4 @@
from common import * # isort:skip
from common import * # isort:skip # noqa: F403
from mock_storage import mock_storage
@ -23,7 +23,9 @@ def is_session_started() -> bool:
return _PROTOCOL_CACHE.get_active_session() is not None
class TestStorageCache(unittest.TestCase):
class TestStorageCache(
unittest.TestCase
): # noqa: F405 # pyright: ignore[reportUndefinedVariable]
def setUp(self):
cache.clear_all()

@ -26,12 +26,12 @@ class TestCryptoBase32(unittest.TestCase):
"PJWHK5DPOVRWW6JANN2W4IDVOBSWYIDEMFRGK3DTNNSSA33EPE======",
),
# fmt: off
(b"中文", "4S4K3ZUWQ4======"), # noqa:E999
(b"中文1", "4S4K3ZUWQ4YQ===="), # noqa:E999
(b"中文12", "4S4K3ZUWQ4YTE==="), # noqa:E999
(b"aécio", "MHB2SY3JN4======"), # noqa:E999
(b"𠜎", "6CQJZDQ="), # noqa:E999
(b"Base64是一種基於64個可列印字元來表示二進制資料的表示方法", # noqa:E999
(b"中文", "4S4K3ZUWQ4======"), # noqa: E999
(b"中文1", "4S4K3ZUWQ4YQ===="), # noqa: E999
(b"中文12", "4S4K3ZUWQ4YTE==="), # noqa: E999
(b"aécio", "MHB2SY3JN4======"), # noqa: E999
(b"𠜎", "6CQJZDQ="), # noqa: E999
(b"Base64是一種基於64個可列印字元來表示二進制資料的表示方法", # noqa: E999
"IJQXGZJWGTTJRL7EXCAOPKFO4WP3VZUWXQ3DJZMARPSY7L7FRCL6LDNQ4WWZPZMFQPSL5BXIUGUOPJF24S5IZ2MAWLSYRNXIWOD6NFUZ46NIJ2FBVDT2JOXGS246NM4V"),
# fmt: on
]

@ -79,7 +79,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
if not utils.USE_THP:
import storage.cache_thp # noQA:F401
import storage.cache_thp # noqa: F401
def test_simple(self):
cid_req_header = make_header(

Loading…
Cancel
Save