1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

feat(core): implement SessionContext

This commit is contained in:
M1nd3r 2024-04-27 02:23:34 +02:00
parent 5bdd2e7fa5
commit 3aa5b88a0d
7 changed files with 85 additions and 33 deletions

View File

@ -213,8 +213,8 @@ trezor.wire.thp
import trezor.wire.thp
trezor.wire.thp.ack_handler
import trezor.wire.thp.ack_handler
trezor.wire.thp.channel_context
import trezor.wire.thp.channel_context
trezor.wire.thp.channel
import trezor.wire.thp.channel
trezor.wire.thp.checksum
import trezor.wire.thp.checksum
trezor.wire.thp.session_context

View File

@ -42,7 +42,7 @@ if TYPE_CHECKING:
T = TypeVar("T")
class UnexpectedMessage(Exception):
class UnexpectedMessageWithId(Exception):
"""A message was received that is not part of the current workflow.
Utility exception to inform the session handler that the current workflow
@ -118,7 +118,7 @@ class CodecContext(Context):
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_types:
raise UnexpectedMessage(msg)
raise UnexpectedMessageWithId(msg)
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
# (and maybe update ctx.session_id - depends on expected behaviour)

View File

@ -137,7 +137,7 @@ async def handle_single_message(
# results of the handler.
res_msg = await task
except context.UnexpectedMessage as exc:
except context.UnexpectedMessageWithId as exc:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises

View File

@ -4,6 +4,8 @@ from trezor import protobuf
if TYPE_CHECKING:
from trezorio import WireInterface
from trezorio import WireInterface
from typing import Container
class Message:
@ -46,6 +48,12 @@ class Context:
self.iface: WireInterface = iface
self.channel_id: bytes = channel_id
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType: ...
async def write(self, msg: protobuf.MessageType) -> None: ...

View File

@ -10,7 +10,7 @@ from trezor import loop, protobuf, utils
from trezor.messages import ThpCreateNewSession
from trezor.wire import message_handler
from ..protocol_common import Context
from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum
from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH
@ -39,7 +39,7 @@ REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000)
class ChannelContext(Context):
class Channel(Context):
def __init__(self, channel_cache: ChannelCache) -> None:
iface = _decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache.channel_id)
@ -56,7 +56,7 @@ class ChannelContext(Context):
@classmethod
def create_new_channel(
cls, iface: WireInterface, buffer: utils.BufferType
) -> "ChannelContext":
) -> "Channel":
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
r = cls(channel_cache)
r.set_buffer(buffer)
@ -217,9 +217,11 @@ class ChannelContext(Context):
if session_state is SessionState.UNALLOCATED:
raise Exception("Unalloacted session")
await self.sessions[session_id].receive_message(
message_type,
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
self.sessions[session_id].incoming_message.publish(
MessageWithType(
message_type,
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
)
)
if state is ChannelState.TH2:
@ -275,6 +277,7 @@ class ChannelContext(Context):
session = SessionContext.create_new_session(self)
print("help")
self.sessions[session.session_id] = session
loop.schedule(session.handle())
print("new session created. Session id:", session.session_id)
def _todo_clear_buffer(self):
@ -300,11 +303,11 @@ class ChannelContext(Context):
return THP.sync_get_send_bit(self.channel_cache) != sync_bit
def load_cached_channels(buffer: utils.BufferType) -> dict[int, ChannelContext]: # TODO
channels: dict[int, ChannelContext] = {}
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c)
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values():
c.set_buffer(buffer)
return channels

View File

@ -1,33 +1,74 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage import cache_thp
from storage.cache_thp import SessionThpCache
from trezor import protobuf
from trezor import loop, protobuf
from trezor.wire import message_handler
from ..protocol_common import Context
from ..protocol_common import Context, MessageWithType
from . import SessionState
from .channel_context import ChannelContext
from .channel import Channel
if TYPE_CHECKING:
from typing import Container # pyright: ignore[reportShadowedImports]
pass
class UnexpectedMessageWithType(Exception):
"""A message was received that is not part of the current workflow.
Utility exception to inform the session handler that the current workflow
should be aborted and a new one started as if `msg` was the first message.
"""
def __init__(self, msg: MessageWithType) -> None:
super().__init__()
self.msg = msg
class SessionContext(Context):
def __init__(
self, channel_context: ChannelContext, session_cache: SessionThpCache
) -> None:
if channel_context.channel_id != session_cache.channel_id:
def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None:
if channel.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
)
super().__init__(channel_context.iface, channel_context.channel_id)
self.channel_context = channel_context
super().__init__(channel.iface, channel.channel_id)
self.channel_context = channel
self.session_cache = session_cache
self.session_id = int.from_bytes(session_cache.session_id, "big")
self.incoming_message = loop.chan()
@classmethod
def create_new_session(cls, channel_context: Channel) -> "SessionContext":
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
return cls(channel_context, session_cache)
async def handle(self) -> None:
take = self.incoming_message.take()
while True:
message = await take
print(message)
# TODO continue similarly to handle_session function in wire.__init__
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
message: MessageWithType = await self.incoming_message.take()
if message.type not in expected_types:
raise UnexpectedMessageWithType(message)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_context.write(msg, self.session_id)
@classmethod
def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext":
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
return cls(channel_context, session_cache)
# ACCESS TO SESSION DATA
def get_session_state(self) -> SessionState:
@ -43,7 +84,7 @@ class SessionContext(Context):
pass # TODO implement
def load_cached_sessions(channel: ChannelContext) -> dict[int, SessionContext]: # TODO
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
for session in cached_sessions:

View File

@ -8,12 +8,12 @@ from trezor import io, log, loop, utils
from .protocol_common import MessageWithId
from .thp import ChannelState, ack_handler, checksum, thp_messages
from .thp import thp_session as THP
from .thp.channel_context import (
from .thp.channel import (
CONT_DATA_OFFSET,
INIT_DATA_OFFSET,
MAX_PAYLOAD_LEN,
REPORT_LENGTH,
ChannelContext,
Channel,
load_cached_channels,
)
from .thp.checksum import CHECKSUM_LENGTH
@ -38,7 +38,7 @@ _PLAINTEXT = 0x01
_BUFFER: bytearray
_BUFFER_LOCK = None
_CHANNEL_CONTEXTS: dict[int, ChannelContext] = {}
_CHANNEL_CONTEXTS: dict[int, Channel] = {}
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
@ -346,7 +346,7 @@ async def _handle_broadcast(
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid")
new_context: ChannelContext = ChannelContext.create_new_channel(iface, _BUFFER)
new_context: Channel = Channel.create_new_channel(iface, _BUFFER)
cid = int.from_bytes(new_context.channel_id, "big")
_CHANNEL_CONTEXTS[cid] = new_context