From 0aaa255a0cabef6c677ee5c5ddb8a0543db7e36a Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 10 Apr 2024 12:18:10 +0200 Subject: [PATCH] Adjust imports --- core/src/apps/thp/create_session.py | 2 +- core/src/trezor/wire/thp/session_context.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/core/src/apps/thp/create_session.py b/core/src/apps/thp/create_session.py index 7e4cded0d..8f62e5375 100644 --- a/core/src/apps/thp/create_session.py +++ b/core/src/apps/thp/create_session.py @@ -1,12 +1,12 @@ from trezor import log, loop from trezor.messages import ThpCreateNewSession, ThpNewSession from trezor.wire.thp import SessionState, channel -from trezor.wire.thp.session_context import SessionContext async def create_new_session( channel: channel.Channel, message: ThpCreateNewSession ) -> ThpNewSession: + from trezor.wire.thp.session_context import SessionContext session = SessionContext.create_new_session(channel) session.set_session_state(SessionState.ALLOCATED) diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 0ffc6a39b..0b3664165 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -8,7 +8,7 @@ from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure from ..protocol_common import Context, MessageWithType from . import SessionState -from . import channel +from .channel import Channel if TYPE_CHECKING: from typing import Container # pyright: ignore[reportShadowedImports] @@ -29,9 +29,7 @@ class UnexpectedMessageWithType(Exception): class SessionContext(Context): - def __init__( - self, channel: channel.Channel, session_cache: SessionThpCache - ) -> None: + def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None: if channel.channel_id != session_cache.channel_id: raise Exception( "The session has different channel id than the provided channel context!" @@ -43,7 +41,7 @@ class SessionContext(Context): self.incoming_message = loop.chan() @classmethod - def create_new_session(cls, channel_context: channel.Channel) -> "SessionContext": + def create_new_session(cls, channel_context: Channel) -> "SessionContext": session_cache = cache_thp.get_new_session(channel_context.channel_cache) return cls(channel_context, session_cache) @@ -147,7 +145,7 @@ class SessionContext(Context): self.session_cache.state = bytearray(state.to_bytes(1, "big")) -def load_cached_sessions(channel: channel.Channel) -> dict[int, SessionContext]: # TODO +def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO if __debug__: log.debug(__name__, "load_cached_sessions") sessions: dict[int, SessionContext] = {}