Clean session creation

M1nd3r/thp5
M1nd3r 1 month ago
parent 38b8e71640
commit 601834d233

@ -690,6 +690,8 @@ if FROZEN:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/tezos/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/Tezos*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/zcash/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/webauthn/*.py'))

@ -776,6 +776,8 @@ if FROZEN:
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/tezos/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/Tezos*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/thp/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/zcash/*.py'))
SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/webauthn/*.py'))

@ -217,6 +217,8 @@ trezor.wire.thp.checksum
import trezor.wire.thp.checksum
trezor.wire.thp.crypto
import trezor.wire.thp.crypto
trezor.wire.thp.handler_provider
import trezor.wire.thp.handler_provider
trezor.wire.thp.pairing_context
import trezor.wire.thp.pairing_context
trezor.wire.thp.session_context

@ -1,13 +1,26 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from trezor.wire.thp.channel import Channel
if TYPE_CHECKING:
from trezor.messages import ThpCreateNewSession, ThpNewSession
from trezor import log, loop
from trezor.messages import ThpCreateNewSession, ThpNewSession
from trezor.wire.thp import SessionState, channel
from trezor.wire.thp.session_context import SessionContext
async def create_new_session(
channel: Channel, message: ThpCreateNewSession
channel: channel.Channel, message: ThpCreateNewSession
) -> ThpNewSession:
new_session_id: int = channel.create_new_session(message.passphrase)
session = SessionContext.create_new_session(channel)
session.set_session_state(SessionState.ALLOCATED)
channel.sessions[session.session_id] = session
loop.schedule(session.handle())
new_session_id: int = session.session_id
if __debug__:
log.debug(
__name__,
"create_new_session - new session created. Passphrase: %s, Session id: %d",
message.passphrase,
session.session_id,
)
print(channel.sessions)
return ThpNewSession(new_session_id=new_session_id)

@ -7,9 +7,10 @@ from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType, MessageType # , ThpPairingMethod
from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession
from trezor.messages import Failure
from trezor.wire import message_handler
from trezor.wire.thp import ack_handler, thp_messages
from trezor.wire.thp.handler_provider import get_handler
from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum, crypto
@ -397,40 +398,12 @@ class Channel(Context):
if __debug__:
log.debug(__name__, "handle_channel_message: %s", message)
# TODO handle other messages than CreateNewSession
if TYPE_CHECKING:
assert isinstance(message, ThpCreateNewSession)
if __debug__:
log.debug(
__name__,
"handle_channel_message - passphrase: %s",
message.passphrase,
)
# await thp_messages.handle_CreateNewSession(message)
new_session_id: int = self.create_new_session(message.passphrase)
# TODO reuse existing buffer and compute size dynamically
bufferrone = bytearray(5)
msg = ThpNewSession(new_session_id=new_session_id)
message_size: int = thp_messages.get_new_session_message(
bufferrone, new_session_id
)
if __debug__:
log.debug(
__name__, "handle_channel_message - message size: %d", message_size
)
_encode_session_into_buffer(memoryview(bufferrone), 0)
if TYPE_CHECKING:
assert msg.MESSAGE_WIRE_TYPE is not None
_encode_message_type_into_buffer(
memoryview(bufferrone), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
)
_encode_message_into_buffer(
memoryview(bufferrone), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
await self.write(ThpNewSession(new_session_id=new_session_id))
# TODO not finished
handler = get_handler(message)
task = handler(self, message)
response_message = await task
# TODO handle
await self.write(response_message)
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
payload_buffer = bytearray(payload)
@ -600,28 +573,6 @@ class Channel(Context):
)
return protobuf.encoded_length(error_message)
def create_new_session(
self,
passphrase: str | None,
) -> int:
if __debug__:
log.debug(__name__, " create_new_session")
from trezor.wire.thp.session_context import SessionContext
session = SessionContext.create_new_session(self)
session.set_session_state(SessionState.ALLOCATED)
self.sessions[session.session_id] = session
loop.schedule(session.handle())
if __debug__:
log.debug(
__name__,
"create_new_session - new session created. Session id: %d",
session.session_id,
)
if __debug__:
print(self.sessions)
return session.session_id
def _todo_clear_buffer(self):
# TODO Buffer clearing not implemented
pass

@ -0,0 +1,16 @@
from typing import TYPE_CHECKING
from trezor import protobuf
from apps.thp import create_session
if TYPE_CHECKING:
from typing import Any, Callable, Coroutine
pass
def get_handler(
msg: protobuf.MessageType,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
return create_session.create_new_session

@ -8,7 +8,7 @@ from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure
from ..protocol_common import Context, MessageWithType
from . import SessionState
from .channel import Channel
from . import channel
if TYPE_CHECKING:
from typing import Container # pyright: ignore[reportShadowedImports]
@ -29,7 +29,9 @@ class UnexpectedMessageWithType(Exception):
class SessionContext(Context):
def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None:
def __init__(
self, channel: channel.Channel, session_cache: SessionThpCache
) -> None:
if channel.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
@ -41,7 +43,7 @@ class SessionContext(Context):
self.incoming_message = loop.chan()
@classmethod
def create_new_session(cls, channel_context: Channel) -> "SessionContext":
def create_new_session(cls, channel_context: channel.Channel) -> "SessionContext":
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
return cls(channel_context, session_cache)
@ -145,7 +147,7 @@ class SessionContext(Context):
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
def load_cached_sessions(channel: channel.Channel) -> dict[int, SessionContext]: # TODO
if __debug__:
log.debug(__name__, "load_cached_sessions")
sessions: dict[int, SessionContext] = {}

@ -2,7 +2,6 @@ import ustruct # pyright:ignore[reportMissingModuleSource]
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf
from trezor.messages import ThpCreateNewSession, ThpNewSession
from .. import message_handler
from ..protocol_common import Message
@ -98,21 +97,9 @@ def get_handshake_completion_response() -> bytes:
)
def get_new_session_message(buffer: bytearray, new_session_id: int) -> int:
msg = ThpNewSession(new_session_id=new_session_id)
encoded_msg = protobuf.encode(buffer, msg)
return encoded_msg
def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType:
print("decode message")
expected_type = protobuf.type_for_wire(msg_type)
x = message_handler.wrap_protobuf_load(buffer, expected_type)
print("result decoded", x)
return x
async def handle_CreateNewSession(msg: ThpCreateNewSession) -> None:
print(msg.passphrase)
print(msg.on_device)
pass

Loading…
Cancel
Save