fix(core): remove unnecessary session_ids in old Codec, fix storage test, suppress invalid warnings

M1nd3r/thp6
M1nd3r 2 months ago
parent 5cd2fe938a
commit eb203faed1

@ -12,12 +12,12 @@ if TYPE_CHECKING:
_MAX_SESSIONS_COUNT = const(10) _MAX_SESSIONS_COUNT = const(10)
_SESSION_ID_LENGTH = const(32) SESSION_ID_LENGTH = const(32)
class SessionCache(DataCache): class SessionCache(DataCache):
def __init__(self) -> None: def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH) self.session_id = bytearray(SESSION_ID_LENGTH)
if utils.BITCOIN_ONLY: if utils.BITCOIN_ONLY:
self.fields = ( self.fields = (
64, # APP_COMMON_SEED 64, # APP_COMMON_SEED
@ -44,7 +44,7 @@ class SessionCache(DataCache):
# generate a new session id if we don't have it yet # generate a new session id if we don't have it yet
if not self.session_id: if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) self.session_id[:] = random.bytes(SESSION_ID_LENGTH)
# export it as immutable bytes # export it as immutable bytes
return bytes(self.session_id) return bytes(self.session_id)
@ -85,7 +85,7 @@ def start_session(received_session_id: bytes | None = None) -> bytes:
if ( if (
received_session_id is not None received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH and len(received_session_id) != SESSION_ID_LENGTH
): ):
# Prevent the caller from setting received_session_id=b"" and finding a cleared # Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know # session. More generally, short-circuit the session id search, because we know

@ -1,12 +1,12 @@
import builtins import builtins
from micropython import const from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage.cache_common import DataCache, InvalidSessionError from storage.cache_common import DataCache, InvalidSessionError
from trezor import utils from trezor import utils
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import TypeVar from typing import TypeVar # pyright: ignore[reportShadowedImports]
T = TypeVar("T") T = TypeVar("T")

@ -1,6 +1,6 @@
import gc import gc
import sys import sys
from trezorutils import ( # noqa: F401 from trezorutils import ( # noqa: F401 # pyright: ignore[reportMissingImports]
BITCOIN_ONLY, BITCOIN_ONLY,
EMULATOR, EMULATOR,
INTERNAL_MODEL, INTERNAL_MODEL,
@ -25,7 +25,7 @@ from trezorutils import ( # noqa: F401
unit_color, unit_color,
unit_packaging, unit_packaging,
) )
from typing import TYPE_CHECKING from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
# Will get replaced by "True" / "False" in the build process # Will get replaced by "True" / "False" in the build process
# However, needs to stay as an exported symbol for the unit tests # However, needs to stay as an exported symbol for the unit tests
@ -37,7 +37,7 @@ USE_THP = True # TODO move elsewhere, probably to core/embed/trezorhal/...
if __debug__: if __debug__:
if EMULATOR: if EMULATOR:
import uos import uos # pyright: ignore[reportMissingModuleSource]
DISABLE_ANIMATION = int(uos.getenv("TREZOR_DISABLE_ANIMATION") or "0") DISABLE_ANIMATION = int(uos.getenv("TREZOR_DISABLE_ANIMATION") or "0")
LOG_MEMORY = int(uos.getenv("TREZOR_LOG_MEMORY") or "0") LOG_MEMORY = int(uos.getenv("TREZOR_LOG_MEMORY") or "0")
@ -45,7 +45,13 @@ if __debug__:
LOG_MEMORY = 0 LOG_MEMORY = 0
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Iterator, Protocol, Sequence, TypeVar from typing import ( # pyright: ignore[reportShadowedImports]
Any,
Iterator,
Protocol,
Sequence,
TypeVar,
)
from trezor.protobuf import MessageType from trezor.protobuf import MessageType
@ -113,13 +119,13 @@ def presize_module(modname: str, size: int) -> None:
if __debug__: if __debug__:
def mem_dump(filename: str) -> None: def mem_dump(filename: str) -> None:
from micropython import mem_info from micropython import mem_info # pyright: ignore[reportMissingModuleSource]
print(f"### sysmodules ({len(sys.modules)}):") print(f"### sysmodules ({len(sys.modules)}):")
for mod in sys.modules: for mod in sys.modules:
print("*", mod) print("*", mod)
if EMULATOR: if EMULATOR:
from trezorutils import meminfo from trezorutils import meminfo # pyright: ignore[reportMissingImports]
print("### dumping to", filename) print("### dumping to", filename)
meminfo(filename) meminfo(filename)

@ -30,7 +30,7 @@ from storage.cache_common import InvalidSessionError
from trezor import log, loop, protobuf, utils from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType from trezor.enums import FailureType
from trezor.messages import Failure from trezor.messages import Failure
from trezor.wire import codec_v1, context, message_handler, protocol_common, thp_v1 from trezor.wire import context, message_handler, protocol_common, thp_v1
from trezor.wire.errors import DataError, Error from trezor.wire.errors import DataError, Error
# Import all errors into namespace, so that `wire.Error` is available from # Import all errors into namespace, so that `wire.Error` is available from
@ -63,7 +63,7 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
if utils.USE_THP: if utils.USE_THP:
loop.schedule(handle_thp_session(iface, is_debug_session)) loop.schedule(handle_thp_session(iface, is_debug_session))
else: else:
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session)) loop.schedule(handle_session(iface, is_debug_session))
def wrap_protobuf_load( def wrap_protobuf_load(
@ -128,15 +128,12 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
print("Exception raised:", exc) print("Exception raised:", exc)
async def handle_session( async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None:
iface: WireInterface, codec_session_id: int, is_debug_session: bool = False
) -> None:
if __debug__ and is_debug_session: if __debug__ and is_debug_session:
ctx_buffer = WIRE_BUFFER_DEBUG ctx_buffer = WIRE_BUFFER_DEBUG
else: else:
ctx_buffer = WIRE_BUFFER ctx_buffer = WIRE_BUFFER
session_id = codec_session_id.to_bytes(4, "big") ctx = context.CodecContext(iface, ctx_buffer)
ctx = context.CodecContext(iface, ctx_buffer, session_id)
next_msg: protocol_common.MessageWithId | None = None next_msg: protocol_common.MessageWithId | None = None
if __debug__ and is_debug_session: if __debug__ and is_debug_session:
@ -165,10 +162,6 @@ async def handle_session(
msg = next_msg msg = next_msg
next_msg = None next_msg = None
# Set ctx.session_id to the value msg.session_id
if msg.session_id is not None:
ctx.channel_id = msg.session_id
try: try:
next_msg_without_id = await message_handler.handle_single_message( next_msg_without_id = await message_handler.handle_single_message(
ctx, msg, use_workflow=not is_debug_session ctx, msg, use_workflow=not is_debug_session

@ -65,12 +65,10 @@ class CodecContext(Context):
self, self,
iface: WireInterface, iface: WireInterface,
buffer: bytearray, buffer: bytearray,
channel_id: bytes,
) -> None: ) -> None:
self.iface = iface self.iface = iface
self.buffer = buffer self.buffer = buffer
self.channel_id = channel_id super().__init__(iface, codec_v1.SESSION_ID.to_bytes(2, "big"))
super().__init__(iface, channel_id)
def read_from_wire(self) -> Awaitable[MessageWithId]: def read_from_wire(self) -> Awaitable[MessageWithId]:
"""Read a whole message from the wire without parsing it.""" """Read a whole message from the wire without parsing it."""
@ -100,15 +98,10 @@ class CodecContext(Context):
to save on having to decode the type code into a protobuf class. to save on having to decode the type code into a protobuf class.
""" """
if __debug__: if __debug__:
if self.channel_id is not None:
sid = int.from_bytes(self.channel_id, "big")
else:
sid = -1
log.debug( log.debug(
__name__, __name__,
"%s:%x expect: %s", "%s: expect: %s",
self.iface.iface_num(), self.iface.iface_num(),
sid,
expected_type.MESSAGE_NAME if expected_type else expected_types, expected_type.MESSAGE_NAME if expected_type else expected_types,
) )
@ -120,22 +113,14 @@ class CodecContext(Context):
if msg.type not in expected_types: if msg.type not in expected_types:
raise UnexpectedMessageWithId(msg) raise UnexpectedMessageWithId(msg)
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
# (and maybe update ctx.session_id - depends on expected behaviour)
if expected_type is None: if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type) expected_type = protobuf.type_for_wire(msg.type)
if __debug__: if __debug__:
if self.channel_id is not None:
sid = int.from_bytes(self.channel_id, "big")
else:
sid = -1
log.debug( log.debug(
__name__, __name__,
"%s:%x read: %s", "%s: read: %s",
self.iface.iface_num(), self.iface.iface_num(),
sid,
expected_type.MESSAGE_NAME, expected_type.MESSAGE_NAME,
) )
@ -147,15 +132,10 @@ class CodecContext(Context):
async def write(self, msg: protobuf.MessageType) -> None: async def write(self, msg: protobuf.MessageType) -> None:
"""Write a message to the wire.""" """Write a message to the wire."""
if __debug__: if __debug__:
if self.channel_id is not None:
sid = int.from_bytes(self.channel_id, "big")
else:
sid = -1
log.debug( log.debug(
__name__, __name__,
"%s:%x write: %s", "%s: write: %s",
self.iface.iface_num(), self.iface.iface_num(),
sid,
msg.MESSAGE_NAME, msg.MESSAGE_NAME,
) )

@ -1,4 +1,4 @@
from common import * # isort:skip from common import * # isort:skip # noqa: F403
from mock_storage import mock_storage from mock_storage import mock_storage
@ -23,7 +23,9 @@ def is_session_started() -> bool:
return _PROTOCOL_CACHE.get_active_session() is not None return _PROTOCOL_CACHE.get_active_session() is not None
class TestStorageCache(unittest.TestCase): class TestStorageCache(
unittest.TestCase
): # noqa: F405 # pyright: ignore[reportUndefinedVariable]
def setUp(self): def setUp(self):
cache.clear_all() cache.clear_all()

@ -26,12 +26,12 @@ class TestCryptoBase32(unittest.TestCase):
"PJWHK5DPOVRWW6JANN2W4IDVOBSWYIDEMFRGK3DTNNSSA33EPE======", "PJWHK5DPOVRWW6JANN2W4IDVOBSWYIDEMFRGK3DTNNSSA33EPE======",
), ),
# fmt: off # fmt: off
(b"中文", "4S4K3ZUWQ4======"), # noqa:E999 (b"中文", "4S4K3ZUWQ4======"), # noqa: E999
(b"中文1", "4S4K3ZUWQ4YQ===="), # noqa:E999 (b"中文1", "4S4K3ZUWQ4YQ===="), # noqa: E999
(b"中文12", "4S4K3ZUWQ4YTE==="), # noqa:E999 (b"中文12", "4S4K3ZUWQ4YTE==="), # noqa: E999
(b"aécio", "MHB2SY3JN4======"), # noqa:E999 (b"aécio", "MHB2SY3JN4======"), # noqa: E999
(b"𠜎", "6CQJZDQ="), # noqa:E999 (b"𠜎", "6CQJZDQ="), # noqa: E999
(b"Base64是一種基於64個可列印字元來表示二進制資料的表示方法", # noqa:E999 (b"Base64是一種基於64個可列印字元來表示二進制資料的表示方法", # noqa: E999
"IJQXGZJWGTTJRL7EXCAOPKFO4WP3VZUWXQ3DJZMARPSY7L7FRCL6LDNQ4WWZPZMFQPSL5BXIUGUOPJF24S5IZ2MAWLSYRNXIWOD6NFUZ46NIJ2FBVDT2JOXGS246NM4V"), "IJQXGZJWGTTJRL7EXCAOPKFO4WP3VZUWXQ3DJZMARPSY7L7FRCL6LDNQ4WWZPZMFQPSL5BXIUGUOPJF24S5IZ2MAWLSYRNXIWOD6NFUZ46NIJ2FBVDT2JOXGS246NM4V"),
# fmt: on # fmt: on
] ]

@ -79,7 +79,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
def setUp(self): def setUp(self):
self.interface = MockHID(0xDEADBEEF) self.interface = MockHID(0xDEADBEEF)
if not utils.USE_THP: if not utils.USE_THP:
import storage.cache_thp # noQA:F401 import storage.cache_thp # noqa: F401
def test_simple(self): def test_simple(self):
cid_req_header = make_header( cid_req_header = make_header(

Loading…
Cancel
Save