mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-13 19:18:56 +00:00
Remove unnecessary session_ids in old Codec, fix storage test, suppress invalid warnings
This commit is contained in:
parent
1ed78fb95d
commit
d0a5b48f36
@ -34,7 +34,7 @@ import trezor.pin # noqa: F401, E402
|
||||
|
||||
# === Prepare the USB interfaces first. Do not connect to the host yet.
|
||||
# usb imports trezor.utils and trezor.io which is a C module
|
||||
import usb # noqa:E402
|
||||
import usb # noqa: E402
|
||||
|
||||
# create an unimport manager that will be reused in the main loop
|
||||
unimport_manager = utils.unimport()
|
||||
@ -45,7 +45,7 @@ with unimport_manager:
|
||||
del boot
|
||||
|
||||
# start the USB
|
||||
import storage.device # noqa:E402
|
||||
import storage.device # noqa: E402
|
||||
|
||||
usb.bus.open(storage.device.get_device_id())
|
||||
|
||||
|
@ -12,12 +12,12 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = const(10)
|
||||
_SESSION_ID_LENGTH = const(32)
|
||||
SESSION_ID_LENGTH = const(32)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
self.session_id = bytearray(SESSION_ID_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
@ -44,7 +44,7 @@ class SessionCache(DataCache):
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||
self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
@ -85,7 +85,7 @@ def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||
|
||||
if (
|
||||
received_session_id is not None
|
||||
and len(received_session_id) != _SESSION_ID_LENGTH
|
||||
and len(received_session_id) != SESSION_ID_LENGTH
|
||||
):
|
||||
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||
# session. More generally, short-circuit the session id search, because we know
|
||||
|
@ -1,12 +1,12 @@
|
||||
import builtins
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from storage.cache_common import DataCache, InvalidSessionError
|
||||
from trezor import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar # pyright: ignore[reportShadowedImports]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
import gc
|
||||
import sys
|
||||
from trezorutils import ( # noqa: F401
|
||||
from trezorutils import ( # noqa: F401 # pyright: ignore[reportMissingImports]
|
||||
BITCOIN_ONLY,
|
||||
EMULATOR,
|
||||
INTERNAL_MODEL,
|
||||
@ -24,7 +24,7 @@ from trezorutils import ( # noqa: F401
|
||||
unit_btconly,
|
||||
unit_color,
|
||||
)
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
# Will get replaced by "True" / "False" in the build process
|
||||
# However, needs to stay as an exported symbol for the unit tests
|
||||
@ -36,7 +36,7 @@ USE_THP = True # TODO move elsewhere, probably to core/embed/trezorhal/...
|
||||
|
||||
if __debug__:
|
||||
if EMULATOR:
|
||||
import uos
|
||||
import uos # pyright: ignore[reportMissingModuleSource]
|
||||
|
||||
DISABLE_ANIMATION = int(uos.getenv("TREZOR_DISABLE_ANIMATION") or "0")
|
||||
LOG_MEMORY = int(uos.getenv("TREZOR_LOG_MEMORY") or "0")
|
||||
@ -44,7 +44,13 @@ if __debug__:
|
||||
LOG_MEMORY = 0
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Iterator, Protocol, Sequence, TypeVar
|
||||
from typing import ( # pyright: ignore[reportShadowedImports]
|
||||
Any,
|
||||
Iterator,
|
||||
Protocol,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from trezor.protobuf import MessageType
|
||||
|
||||
@ -112,13 +118,13 @@ def presize_module(modname: str, size: int) -> None:
|
||||
if __debug__:
|
||||
|
||||
def mem_dump(filename: str) -> None:
|
||||
from micropython import mem_info
|
||||
from micropython import mem_info # pyright: ignore[reportMissingModuleSource]
|
||||
|
||||
print(f"### sysmodules ({len(sys.modules)}):")
|
||||
for mod in sys.modules:
|
||||
print("*", mod)
|
||||
if EMULATOR:
|
||||
from trezorutils import meminfo
|
||||
from trezorutils import meminfo # pyright: ignore[reportMissingImports]
|
||||
|
||||
print("### dumping to", filename)
|
||||
meminfo(filename)
|
||||
|
@ -30,7 +30,7 @@ from storage.cache_common import InvalidSessionError
|
||||
from trezor import log, loop, protobuf, utils
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor.wire import codec_v1, context, message_handler, protocol_common, thp_v1
|
||||
from trezor.wire import context, message_handler, protocol_common, thp_v1
|
||||
from trezor.wire.errors import DataError, Error
|
||||
|
||||
# Import all errors into namespace, so that `wire.Error` is available from
|
||||
@ -63,7 +63,7 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
||||
if utils.USE_THP:
|
||||
loop.schedule(handle_thp_session(iface, is_debug_session))
|
||||
else:
|
||||
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
|
||||
loop.schedule(handle_session(iface, is_debug_session))
|
||||
|
||||
|
||||
def wrap_protobuf_load(
|
||||
@ -128,15 +128,12 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
|
||||
print("Exception raised:", exc)
|
||||
|
||||
|
||||
async def handle_session(
|
||||
iface: WireInterface, codec_session_id: int, is_debug_session: bool = False
|
||||
) -> None:
|
||||
async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None:
|
||||
if __debug__ and is_debug_session:
|
||||
ctx_buffer = WIRE_BUFFER_DEBUG
|
||||
else:
|
||||
ctx_buffer = WIRE_BUFFER
|
||||
session_id = codec_session_id.to_bytes(4, "big")
|
||||
ctx = context.CodecContext(iface, ctx_buffer, session_id)
|
||||
ctx = context.CodecContext(iface, ctx_buffer)
|
||||
next_msg: protocol_common.MessageWithId | None = None
|
||||
|
||||
if __debug__ and is_debug_session:
|
||||
@ -165,10 +162,6 @@ async def handle_session(
|
||||
msg = next_msg
|
||||
next_msg = None
|
||||
|
||||
# Set ctx.session_id to the value msg.session_id
|
||||
if msg.session_id is not None:
|
||||
ctx.channel_id = msg.session_id
|
||||
|
||||
try:
|
||||
next_msg_without_id = await message_handler.handle_single_message(
|
||||
ctx, msg, use_workflow=not is_debug_session
|
||||
|
@ -65,12 +65,10 @@ class CodecContext(Context):
|
||||
self,
|
||||
iface: WireInterface,
|
||||
buffer: bytearray,
|
||||
channel_id: bytes,
|
||||
) -> None:
|
||||
self.iface = iface
|
||||
self.buffer = buffer
|
||||
self.channel_id = channel_id
|
||||
super().__init__(iface, channel_id)
|
||||
super().__init__(iface, codec_v1.SESSION_ID.to_bytes(2, "big"))
|
||||
|
||||
def read_from_wire(self) -> Awaitable[MessageWithId]:
|
||||
"""Read a whole message from the wire without parsing it."""
|
||||
@ -100,15 +98,10 @@ class CodecContext(Context):
|
||||
to save on having to decode the type code into a protobuf class.
|
||||
"""
|
||||
if __debug__:
|
||||
if self.channel_id is not None:
|
||||
sid = int.from_bytes(self.channel_id, "big")
|
||||
else:
|
||||
sid = -1
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x expect: %s",
|
||||
"%s: expect: %s",
|
||||
self.iface.iface_num(),
|
||||
sid,
|
||||
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||
)
|
||||
|
||||
@ -120,22 +113,14 @@ class CodecContext(Context):
|
||||
if msg.type not in expected_types:
|
||||
raise UnexpectedMessageWithId(msg)
|
||||
|
||||
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
|
||||
# (and maybe update ctx.session_id - depends on expected behaviour)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
if __debug__:
|
||||
if self.channel_id is not None:
|
||||
sid = int.from_bytes(self.channel_id, "big")
|
||||
else:
|
||||
sid = -1
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x read: %s",
|
||||
"%s: read: %s",
|
||||
self.iface.iface_num(),
|
||||
sid,
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
@ -147,15 +132,10 @@ class CodecContext(Context):
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
"""Write a message to the wire."""
|
||||
if __debug__:
|
||||
if self.channel_id is not None:
|
||||
sid = int.from_bytes(self.channel_id, "big")
|
||||
else:
|
||||
sid = -1
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x write: %s",
|
||||
"%s: write: %s",
|
||||
self.iface.iface_num(),
|
||||
sid,
|
||||
msg.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from common import * # isort:skip
|
||||
from common import * # isort:skip # noqa: F403
|
||||
|
||||
from mock_storage import mock_storage
|
||||
|
||||
@ -23,7 +23,9 @@ def is_session_started() -> bool:
|
||||
return _PROTOCOL_CACHE.get_active_session() is not None
|
||||
|
||||
|
||||
class TestStorageCache(unittest.TestCase):
|
||||
class TestStorageCache(
|
||||
unittest.TestCase
|
||||
): # noqa: F405 # pyright: ignore[reportUndefinedVariable]
|
||||
def setUp(self):
|
||||
cache.clear_all()
|
||||
|
||||
|
@ -26,12 +26,12 @@ class TestCryptoBase32(unittest.TestCase):
|
||||
"PJWHK5DPOVRWW6JANN2W4IDVOBSWYIDEMFRGK3DTNNSSA33EPE======",
|
||||
),
|
||||
# fmt: off
|
||||
(b"中文", "4S4K3ZUWQ4======"), # noqa:E999
|
||||
(b"中文1", "4S4K3ZUWQ4YQ===="), # noqa:E999
|
||||
(b"中文12", "4S4K3ZUWQ4YTE==="), # noqa:E999
|
||||
(b"aécio", "MHB2SY3JN4======"), # noqa:E999
|
||||
(b"𠜎", "6CQJZDQ="), # noqa:E999
|
||||
(b"Base64是一種基於64個可列印字元來表示二進制資料的表示方法", # noqa:E999
|
||||
(b"中文", "4S4K3ZUWQ4======"), # noqa: E999
|
||||
(b"中文1", "4S4K3ZUWQ4YQ===="), # noqa: E999
|
||||
(b"中文12", "4S4K3ZUWQ4YTE==="), # noqa: E999
|
||||
(b"aécio", "MHB2SY3JN4======"), # noqa: E999
|
||||
(b"𠜎", "6CQJZDQ="), # noqa: E999
|
||||
(b"Base64是一種基於64個可列印字元來表示二進制資料的表示方法", # noqa: E999
|
||||
"IJQXGZJWGTTJRL7EXCAOPKFO4WP3VZUWXQ3DJZMARPSY7L7FRCL6LDNQ4WWZPZMFQPSL5BXIUGUOPJF24S5IZ2MAWLSYRNXIWOD6NFUZ46NIJ2FBVDT2JOXGS246NM4V"),
|
||||
# fmt: on
|
||||
]
|
||||
|
@ -79,7 +79,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.interface = MockHID(0xDEADBEEF)
|
||||
if not utils.USE_THP:
|
||||
import storage.cache_thp # noQA:F401
|
||||
import storage.cache_thp # noqa: F401
|
||||
|
||||
def test_simple(self):
|
||||
cid_req_header = make_header(
|
||||
|
Loading…
Reference in New Issue
Block a user