mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
feat(core): implement SessionContext
This commit is contained in:
parent
5bdd2e7fa5
commit
3aa5b88a0d
4
core/src/all_modules.py
generated
4
core/src/all_modules.py
generated
@ -213,8 +213,8 @@ trezor.wire.thp
|
||||
import trezor.wire.thp
|
||||
trezor.wire.thp.ack_handler
|
||||
import trezor.wire.thp.ack_handler
|
||||
trezor.wire.thp.channel_context
|
||||
import trezor.wire.thp.channel_context
|
||||
trezor.wire.thp.channel
|
||||
import trezor.wire.thp.channel
|
||||
trezor.wire.thp.checksum
|
||||
import trezor.wire.thp.checksum
|
||||
trezor.wire.thp.session_context
|
||||
|
@ -42,7 +42,7 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class UnexpectedMessage(Exception):
|
||||
class UnexpectedMessageWithId(Exception):
|
||||
"""A message was received that is not part of the current workflow.
|
||||
|
||||
Utility exception to inform the session handler that the current workflow
|
||||
@ -118,7 +118,7 @@ class CodecContext(Context):
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if msg.type not in expected_types:
|
||||
raise UnexpectedMessage(msg)
|
||||
raise UnexpectedMessageWithId(msg)
|
||||
|
||||
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
|
||||
# (and maybe update ctx.session_id - depends on expected behaviour)
|
||||
|
@ -137,7 +137,7 @@ async def handle_single_message(
|
||||
# results of the handler.
|
||||
res_msg = await task
|
||||
|
||||
except context.UnexpectedMessage as exc:
|
||||
except context.UnexpectedMessageWithId as exc:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
|
@ -4,6 +4,8 @@ from trezor import protobuf
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from trezorio import WireInterface
|
||||
from typing import Container
|
||||
|
||||
|
||||
class Message:
|
||||
@ -46,6 +48,12 @@ class Context:
|
||||
self.iface: WireInterface = iface
|
||||
self.channel_id: bytes = channel_id
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None: ...
|
||||
|
||||
|
||||
|
@ -10,7 +10,7 @@ from trezor import loop, protobuf, utils
|
||||
from trezor.messages import ThpCreateNewSession
|
||||
from trezor.wire import message_handler
|
||||
|
||||
from ..protocol_common import Context
|
||||
from ..protocol_common import Context, MessageWithType
|
||||
from . import ChannelState, SessionState, checksum
|
||||
from . import thp_session as THP
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
@ -39,7 +39,7 @@ REPORT_LENGTH = const(64)
|
||||
MAX_PAYLOAD_LEN = const(60000)
|
||||
|
||||
|
||||
class ChannelContext(Context):
|
||||
class Channel(Context):
|
||||
def __init__(self, channel_cache: ChannelCache) -> None:
|
||||
iface = _decode_iface(channel_cache.iface)
|
||||
super().__init__(iface, channel_cache.channel_id)
|
||||
@ -56,7 +56,7 @@ class ChannelContext(Context):
|
||||
@classmethod
|
||||
def create_new_channel(
|
||||
cls, iface: WireInterface, buffer: utils.BufferType
|
||||
) -> "ChannelContext":
|
||||
) -> "Channel":
|
||||
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
|
||||
r = cls(channel_cache)
|
||||
r.set_buffer(buffer)
|
||||
@ -217,9 +217,11 @@ class ChannelContext(Context):
|
||||
if session_state is SessionState.UNALLOCATED:
|
||||
raise Exception("Unalloacted session")
|
||||
|
||||
await self.sessions[session_id].receive_message(
|
||||
message_type,
|
||||
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
|
||||
self.sessions[session_id].incoming_message.publish(
|
||||
MessageWithType(
|
||||
message_type,
|
||||
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
|
||||
)
|
||||
)
|
||||
|
||||
if state is ChannelState.TH2:
|
||||
@ -275,6 +277,7 @@ class ChannelContext(Context):
|
||||
session = SessionContext.create_new_session(self)
|
||||
print("help")
|
||||
self.sessions[session.session_id] = session
|
||||
loop.schedule(session.handle())
|
||||
print("new session created. Session id:", session.session_id)
|
||||
|
||||
def _todo_clear_buffer(self):
|
||||
@ -300,11 +303,11 @@ class ChannelContext(Context):
|
||||
return THP.sync_get_send_bit(self.channel_cache) != sync_bit
|
||||
|
||||
|
||||
def load_cached_channels(buffer: utils.BufferType) -> dict[int, ChannelContext]: # TODO
|
||||
channels: dict[int, ChannelContext] = {}
|
||||
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
|
||||
channels: dict[int, Channel] = {}
|
||||
cached_channels = cache_thp.get_all_allocated_channels()
|
||||
for c in cached_channels:
|
||||
channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c)
|
||||
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
|
||||
for c in channels.values():
|
||||
c.set_buffer(buffer)
|
||||
return channels
|
@ -1,33 +1,74 @@
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import SessionThpCache
|
||||
from trezor import protobuf
|
||||
from trezor import loop, protobuf
|
||||
from trezor.wire import message_handler
|
||||
|
||||
from ..protocol_common import Context
|
||||
from ..protocol_common import Context, MessageWithType
|
||||
from . import SessionState
|
||||
from .channel_context import ChannelContext
|
||||
from .channel import Channel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Container # pyright: ignore[reportShadowedImports]
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UnexpectedMessageWithType(Exception):
|
||||
"""A message was received that is not part of the current workflow.
|
||||
|
||||
Utility exception to inform the session handler that the current workflow
|
||||
should be aborted and a new one started as if `msg` was the first message.
|
||||
"""
|
||||
|
||||
def __init__(self, msg: MessageWithType) -> None:
|
||||
super().__init__()
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class SessionContext(Context):
|
||||
def __init__(
|
||||
self, channel_context: ChannelContext, session_cache: SessionThpCache
|
||||
) -> None:
|
||||
if channel_context.channel_id != session_cache.channel_id:
|
||||
def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None:
|
||||
if channel.channel_id != session_cache.channel_id:
|
||||
raise Exception(
|
||||
"The session has different channel id than the provided channel context!"
|
||||
)
|
||||
super().__init__(channel_context.iface, channel_context.channel_id)
|
||||
self.channel_context = channel_context
|
||||
super().__init__(channel.iface, channel.channel_id)
|
||||
self.channel_context = channel
|
||||
self.session_cache = session_cache
|
||||
self.session_id = int.from_bytes(session_cache.session_id, "big")
|
||||
self.incoming_message = loop.chan()
|
||||
|
||||
@classmethod
|
||||
def create_new_session(cls, channel_context: Channel) -> "SessionContext":
|
||||
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
|
||||
return cls(channel_context, session_cache)
|
||||
|
||||
async def handle(self) -> None:
|
||||
take = self.incoming_message.take()
|
||||
while True:
|
||||
message = await take
|
||||
print(message)
|
||||
# TODO continue similarly to handle_session function in wire.__init__
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
|
||||
message: MessageWithType = await self.incoming_message.take()
|
||||
if message.type not in expected_types:
|
||||
raise UnexpectedMessageWithType(message)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(message.type)
|
||||
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel_context.write(msg, self.session_id)
|
||||
|
||||
@classmethod
|
||||
def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext":
|
||||
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
|
||||
return cls(channel_context, session_cache)
|
||||
|
||||
# ACCESS TO SESSION DATA
|
||||
|
||||
def get_session_state(self) -> SessionState:
|
||||
@ -43,7 +84,7 @@ class SessionContext(Context):
|
||||
pass # TODO implement
|
||||
|
||||
|
||||
def load_cached_sessions(channel: ChannelContext) -> dict[int, SessionContext]: # TODO
|
||||
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
|
||||
sessions: dict[int, SessionContext] = {}
|
||||
cached_sessions = cache_thp.get_all_allocated_sessions()
|
||||
for session in cached_sessions:
|
||||
|
@ -8,12 +8,12 @@ from trezor import io, log, loop, utils
|
||||
from .protocol_common import MessageWithId
|
||||
from .thp import ChannelState, ack_handler, checksum, thp_messages
|
||||
from .thp import thp_session as THP
|
||||
from .thp.channel_context import (
|
||||
from .thp.channel import (
|
||||
CONT_DATA_OFFSET,
|
||||
INIT_DATA_OFFSET,
|
||||
MAX_PAYLOAD_LEN,
|
||||
REPORT_LENGTH,
|
||||
ChannelContext,
|
||||
Channel,
|
||||
load_cached_channels,
|
||||
)
|
||||
from .thp.checksum import CHECKSUM_LENGTH
|
||||
@ -38,7 +38,7 @@ _PLAINTEXT = 0x01
|
||||
_BUFFER: bytearray
|
||||
_BUFFER_LOCK = None
|
||||
|
||||
_CHANNEL_CONTEXTS: dict[int, ChannelContext] = {}
|
||||
_CHANNEL_CONTEXTS: dict[int, Channel] = {}
|
||||
|
||||
|
||||
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
|
||||
@ -346,7 +346,7 @@ async def _handle_broadcast(
|
||||
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||
raise ThpError("Checksum is not valid")
|
||||
|
||||
new_context: ChannelContext = ChannelContext.create_new_channel(iface, _BUFFER)
|
||||
new_context: Channel = Channel.create_new_channel(iface, _BUFFER)
|
||||
cid = int.from_bytes(new_context.channel_id, "big")
|
||||
_CHANNEL_CONTEXTS[cid] = new_context
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user