feat(core): add management session

M1nd3r/thp7
M1nd3r 2 weeks ago
parent 06b1784299
commit 4a201ab079

@ -1,44 +1,51 @@
from typing import TYPE_CHECKING
from trezor import log, loop
from trezor.enums import FailureType
from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession
from trezor.enums import ButtonRequestType, FailureType
from trezor.messages import (
ButtonAck,
ButtonRequest,
Failure,
ThpCreateNewSession,
ThpNewSession,
)
from trezor.wire.context import call
from trezor.wire.errors import ActionCancelled, DataError
from trezor.wire.thp import SessionState
if TYPE_CHECKING:
from trezor.wire.thp import ChannelContext
from trezor.wire.thp.session_context import ManagementSessionContext
async def create_new_session(
channel: ChannelContext, message: ThpCreateNewSession
management_session: ManagementSessionContext, message: ThpCreateNewSession
) -> ThpNewSession | Failure:
from trezor.wire.thp.session_manager import create_new_session
from apps.common.seed import derive_and_store_roots
session = create_new_session(channel)
new_session = create_new_session(management_session.channel_ctx)
try:
await derive_and_store_roots(session, message)
await derive_and_store_roots(new_session, message)
except DataError as e:
return Failure(code=FailureType.DataError, message=e.message)
except ActionCancelled as e:
return Failure(code=FailureType.ActionCancelled, message=e.message)
# TODO handle other errors
session.set_session_state(SessionState.ALLOCATED)
channel.sessions[session.session_id] = session
loop.schedule(session.handle())
new_session_id: int = session.session_id
new_session.set_session_state(SessionState.ALLOCATED)
management_session.channel_ctx.sessions[new_session.session_id] = new_session
loop.schedule(new_session.handle())
new_session_id: int = new_session.session_id
# await get_seed() TODO
if __debug__:
log.debug(
__name__,
"create_new_session - new session created. Passphrase: %s, Session id: %d",
"create_new_session - new session created. Passphrase: %s, Session id: %d\n%s",
message.passphrase if message.passphrase is not None else "",
session.session_id,
new_session.session_id,
str(management_session.channel_ctx.sessions),
)
print(channel.sessions)
return ThpNewSession(new_session_id=new_session_id)

@ -27,6 +27,7 @@ BROADCAST_CHANNEL_ID = const(65535)
KEY_LENGTH = const(32)
TAG_LENGTH = const(16)
_UNALLOCATED_STATE = const(0)
MANAGEMENT_SESSION_ID = const(0)
class ConnectionCache(DataCache):
@ -157,13 +158,18 @@ def get_all_allocated_channels() -> list[ChannelCache]:
def get_all_allocated_sessions() -> list[SessionThpCache]:
if __debug__:
from trezor.utils import get_bytes_as_str
_list: list[SessionThpCache] = []
for session in _SESSIONS:
if _get_session_state(session) != _UNALLOCATED_STATE:
_list.append(session)
if __debug__:
log.debug(
__name__, "session %s is not in UNALLOCATED state", str(session)
__name__,
"session with channel_id: %s and session_id: %s is in ALLOCATED state",
get_bytes_as_str(session.channel_id),
get_bytes_as_str(session.session_id),
)
elif __debug__:
log.debug(__name__, "session %s is in UNALLOCATED state", str(session))

@ -21,6 +21,7 @@ if TYPE_CHECKING:
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[[Msg], HandlerTask]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
HandlerFinder = Callable[[Any, Any], Handler | None]
# If set to False protobuf messages marked with "experimental_message" option are rejected.
@ -57,7 +58,10 @@ if __debug__:
async def handle_single_message(
ctx: context.Context, msg: protocol_common.MessageWithType, use_workflow: bool
ctx: context.Context,
msg: protocol_common.MessageWithType,
use_workflow: bool,
special_handler_finder: HandlerFinder | None = None,
) -> protocol_common.MessageWithType | None:
"""Handle a message that was loaded from USB by the caller.
@ -98,7 +102,12 @@ async def handle_single_message(
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type. Should not raise.
handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none
if special_handler_finder is not None:
handler: Handler | None = special_handler_finder(ctx.iface, msg.type)
else:
handler: Handler | None = find_handler( # pylint: disable=assignment-from-none
ctx.iface, msg.type
)
if handler is None:
# If no handler is found, we can skip decoding and directly
@ -175,7 +184,7 @@ def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler |
return None
find_handler = _find_handler_placeholder
find_handler: HandlerFinder = _find_handler_placeholder
AVOID_RESTARTING_FOR: Container[int] = ()

@ -3,13 +3,13 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
if TYPE_CHECKING:
from enum import IntEnum
from trezorio import WireInterface
from typing import Protocol, TypeVar, overload
from typing import List, Protocol, TypeVar, overload
from storage.cache_thp import ChannelCache
from trezor import loop, protobuf, utils
from trezor.enums import FailureType
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.session_context import SessionContext
from trezor.wire.thp.session_context import GenericSessionContext
T = TypeVar("T")
@ -18,8 +18,8 @@ if TYPE_CHECKING:
iface: WireInterface
channel_id: bytes
channel_cache: ChannelCache
selected_pairing_methods = [] # TODO add type
sessions: dict[int, SessionContext]
selected_pairing_methods: List[int] = [] # TODO add type
sessions: dict[int, GenericSessionContext]
waiting_for_ack_timeout: loop.spawn | None
write_task_spawn: loop.spawn | None
connection_context: PairingContext | None
@ -73,6 +73,7 @@ class ChannelState(IntEnum):
class SessionState(IntEnum):
UNALLOCATED = 0
ALLOCATED = 1
MANAGEMENT = 2
class WireInterfaceType(IntEnum):

@ -4,9 +4,17 @@ from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
from storage.cache_thp import TAG_LENGTH, ChannelCache
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.wire.thp import interface_manager, received_message_handler
from . import ChannelState, checksum, control_byte, crypto, memory_manager
from . import (
ChannelState,
checksum,
control_byte,
crypto,
interface_manager,
memory_manager,
received_message_handler,
session_manager,
)
from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH
from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader
@ -25,10 +33,8 @@ if TYPE_CHECKING:
from trezorio import WireInterface
from typing import TypeVar, overload
from . import ChannelContext, PairingContext
from .session_context import SessionContext
else:
ChannelContext = object
from . import PairingContext
from .session_context import GenericSessionContext
class Channel:
@ -43,10 +49,11 @@ class Channel:
self.buffer: utils.BufferType
self.channel_id: bytes = channel_cache.channel_id
self.selected_pairing_methods = []
self.sessions: dict[int, SessionContext] = {}
self.sessions: dict[int, GenericSessionContext] = {}
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
self._create_management_session()
# ACCESS TO CHANNEL_DATA
def get_channel_id_int(self) -> int:
@ -68,6 +75,11 @@ class Channel:
if __debug__:
log.debug(__name__, "set_buffer: %s", type(self.buffer))
def _create_management_session(self) -> None:
session = session_manager.create_new_management_session(self)
self.sessions[session.session_id] = session
loop.schedule(session.handle())
# CALLED BY THP_MAIN_LOOP
async def receive_packet(self, packet: utils.BufferType):

@ -8,13 +8,13 @@ from apps.base import get_features
from apps.thp import create_session
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Any, Callable, Coroutine
from trezor.messages import Features, GetFeatures, LoadDevice
from . import ChannelContext
pass
from .. import Handler
from .session_context import ManagementSessionContext
def get_handler_for_channel_message(
@ -29,7 +29,7 @@ def get_handler_for_channel_message(
from apps.debug.load_device import load_device
def wrapper(
channel: ChannelContext, msg: LoadDevice
channel: ManagementSessionContext, msg: LoadDevice
) -> Coroutine[Any, Any, protobuf.MessageType]:
return load_device(msg)
@ -37,5 +37,18 @@ def get_handler_for_channel_message(
raise UnexpectedMessage("There is no handler available for this message")
async def handle_GetFeatures(ctx: ChannelContext, msg: GetFeatures) -> Features:
async def handle_GetFeatures(
ctx: ManagementSessionContext, msg: GetFeatures
) -> Features:
return get_features()
def get_handler_finder_for_message(ctx: ManagementSessionContext):
def finder(iface: WireInterface, msg_type: int) -> Handler | None:
def handler_wrap(msg: protobuf.MessageType):
handler = get_handler_for_channel_message(msg)
return handler(ctx, msg)
return handler_wrap
return finder

@ -260,9 +260,6 @@ async def _handle_state_ENCRYPTED_TRANSPORT(
ctx.decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", ctx.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
await _handle_channel_message(ctx, message_length, message_type)
return
if session_id not in ctx.sessions:
await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session")
raise ThpError("Unalloacted session")

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage.cache_thp import SessionThpCache
from storage.cache_thp import MANAGEMENT_SESSION_ID, SessionThpCache
from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure
@ -19,6 +19,7 @@ if TYPE_CHECKING:
from storage.cache_common import DataCache
from ..message_handler import HandlerFinder
from . import ChannelContext
pass
@ -26,6 +27,9 @@ if TYPE_CHECKING:
_EXIT_LOOP = True
_REPEAT_LOOP = False
if __debug__:
from trezor.utils import get_bytes_as_str
class UnexpectedMessageWithType(Exception):
"""A message was received that is not part of the current workflow.
@ -39,19 +43,13 @@ class UnexpectedMessageWithType(Exception):
self.msg = msg
class SessionContext(Context):
def __init__(
self, channel_ctx: ChannelContext, session_cache: SessionThpCache
) -> None:
if channel_ctx.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
)
class GenericSessionContext(Context):
def __init__(self, channel_ctx: ChannelContext, session_id: int) -> None:
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel_ctx = channel_ctx
self.session_cache = session_cache
self.session_id = int.from_bytes(session_cache.session_id, "big")
self.channel_ctx: ChannelContext = channel_ctx
self.session_id: int = session_id
self.incoming_message = loop.chan()
self.special_handler_finder: HandlerFinder | None = None
async def handle(self, is_debug_session: bool = False) -> None:
if __debug__:
@ -73,7 +71,12 @@ class SessionContext(Context):
log.exception(__name__, exc)
def _handle_debug(self, is_debug_session: bool) -> None:
log.debug(__name__, "handle - start (session_id: %d)", self.session_id)
log.debug(
__name__,
"handle - start (channel_id (bytes): %s, session_id: %d)",
get_bytes_as_str(self.channel_id),
self.session_id,
)
if is_debug_session:
import apps.debug
@ -96,7 +99,10 @@ class SessionContext(Context):
try:
next_message = await message_handler.handle_single_message(
self, message, use_workflow=not is_debug_session
self,
message,
use_workflow=not is_debug_session,
special_handler_finder=self.special_handler_finder,
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
@ -156,6 +162,33 @@ class SessionContext(Context):
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_ctx.write(msg, self.session_id)
def get_session_state(self) -> SessionState: ...
class ManagementSessionContext(GenericSessionContext):
def __init__(self, channel_ctx: ChannelContext) -> None:
super().__init__(channel_ctx, MANAGEMENT_SESSION_ID)
from trezor.wire.thp.handler_provider import get_handler_finder_for_message
finder = get_handler_finder_for_message(self)
self.special_handler_finder = finder
def get_session_state(self) -> SessionState:
return SessionState.MANAGEMENT
class SessionContext(GenericSessionContext):
def __init__(
self, channel_ctx: ChannelContext, session_cache: SessionThpCache
) -> None:
if channel_ctx.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
)
session_id = int.from_bytes(session_cache.session_id, "big")
super().__init__(channel_ctx, session_id)
self.session_cache = session_cache
# ACCESS TO SESSION DATA
def get_session_state(self) -> SessionState:

@ -3,7 +3,11 @@ from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import loop
from .session_context import SessionContext
from .session_context import (
GenericSessionContext,
ManagementSessionContext,
SessionContext,
)
if __debug__:
from trezor import log
@ -17,10 +21,18 @@ def create_new_session(channel_ctx: ChannelContext) -> SessionContext:
return SessionContext(channel_ctx, session_cache)
def load_cached_sessions(channel_ctx: ChannelContext) -> dict[int, SessionContext]:
def create_new_management_session(
channel_ctx: ChannelContext,
) -> ManagementSessionContext:
return ManagementSessionContext(channel_ctx)
def load_cached_sessions(
channel_ctx: ChannelContext,
) -> dict[int, GenericSessionContext]:
if __debug__:
log.debug(__name__, "load_cached_sessions")
sessions: dict[int, SessionContext] = {}
sessions: dict[int, GenericSessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
if __debug__:
log.debug(

@ -35,7 +35,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
global _BUFFER
CHANNELS = channel_manager.load_cached_channels(_BUFFER)
for ch in CHANNELS.values():
ch.sessions = session_manager.load_cached_sessions(ch)
ch.sessions.update(session_manager.load_cached_sessions(ch))
read = loop.wait(iface.iface_num() | io.POLL_READ)

Loading…
Cancel
Save