Implement SessionContext write

Branch: M1nd3r/thp5
Author: M1nd3r (2 months ago)
parent 4a76216cbf
commit 117f72c689

@@ -20,7 +20,7 @@ _CHANNEL_STATE_LENGTH = const(1)
 _WIRE_INTERFACE_LENGTH = const(1)
 _SESSION_STATE_LENGTH = const(1)
 _CHANNEL_ID_LENGTH = const(2)
-_SESSION_ID_LENGTH = const(1)
+SESSION_ID_LENGTH = const(1)
 BROADCAST_CHANNEL_ID = const(65535)
 KEY_LENGTH = const(32)
 TAG_LENGTH = const(16)
@@ -61,7 +61,7 @@ class ChannelCache(ConnectionCache):
 
 class SessionThpCache(ConnectionCache):
     def __init__(self) -> None:
-        self.session_id = bytearray(_SESSION_ID_LENGTH)
+        self.session_id = bytearray(SESSION_ID_LENGTH)
         self.state = bytearray(_SESSION_STATE_LENGTH)
         if utils.BITCOIN_ONLY:
             self.fields = (
@@ -284,7 +284,7 @@ def get_next_session_id(channel: ChannelCache) -> bytes:
         if _is_session_id_unique(channel):
             break
     new_sid = channel.session_id_counter
-    return new_sid.to_bytes(_SESSION_ID_LENGTH, "big")
+    return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
 
 
 def _is_session_id_unique(channel: ChannelCache) -> bool:
@@ -307,10 +307,8 @@ def _get_cid(session: SessionThpCache) -> int:
 
 def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache:
-    if len(session_id) != _SESSION_ID_LENGTH:
-        raise ValueError(
-            "session_id must be X bytes long, where X=", _SESSION_ID_LENGTH
-        )
+    if len(session_id) != SESSION_ID_LENGTH:
+        raise ValueError("session_id must be X bytes long, where X=", SESSION_ID_LENGTH)
     global _active_session_idx
     global _is_active_session_authenticated
     global _next_unauthenicated_session_index
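For orientation, get_next_session_id (context in the hunk above) keeps bumping the channel's one-byte session_id_counter until _is_session_id_unique accepts it, then returns the id as big-endian bytes. A minimal sketch of that allocation pattern, with a hypothetical taken_ids set standing in for the module's real uniqueness check:

SESSION_ID_LENGTH = 1  # one-byte session ids, as in cache_thp above

def next_session_id(counter: int, taken_ids: set[int]) -> tuple[bytes, int]:
    # Keep bumping the counter (wrapping at 256 in this sketch) until an id
    # is free; the firmware uses _is_session_id_unique(channel) instead of a set.
    while True:
        counter = (counter + 1) % (1 << (8 * SESSION_ID_LENGTH))
        if counter not in taken_ids:
            break
    return counter.to_bytes(SESSION_ID_LENGTH, "big"), counter

print(next_session_id(0, {1, 2}))  # -> (b'\x03', 3)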

@@ -177,9 +177,15 @@ async def handle_session(
             ctx.channel_id = msg.session_id
 
             try:
-                next_msg = await message_handler.handle_single_message(
+                next_msg_without_id = await message_handler.handle_single_message(
                     ctx, msg, use_workflow=not is_debug_session
                 )
+                if next_msg_without_id is not None:
+                    next_msg = protocol_common.MessageWithId(
+                        next_msg_without_id.type,
+                        next_msg_without_id.data,
+                        bytearray(ctx.channel_id),
+                    )
             except Exception as exc:
                 # Log and ignore. The session handler can only exit explicitly in the
                 # following finally block.
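handle_single_message now hands back a reply without any id attached; the session loop re-wraps it so the channel id travels with it into the next iteration. A rough stand-in for that container (the constructor order matches the call above, but the field names beyond type and data are assumptions, not the firmware's definitions):

class MessageWithId:
    # Hypothetical sketch of protocol_common.MessageWithId: a wire message
    # type, its encoded payload, and the channel/session id it belongs to.
    def __init__(self, message_type: int, message_data: bytes, session_id: bytearray) -> None:
        self.type = message_type
        self.data = message_data
        self.session_id = session_id

reply = MessageWithId(17, b"\x08\x01", bytearray(b"\x12\x34"))
print(reply.type, bytes(reply.session_id))  # -> 17 b'\x124'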

@@ -186,7 +186,7 @@ class CodecContext(Context):
         )
 
 
-CURRENT_CONTEXT: CodecContext | None = None
+CURRENT_CONTEXT: Context | None = None
 
 
 def wait(task: Awaitable[T]) -> Awaitable[T]:
@@ -251,7 +251,7 @@ async def maybe_call(
     await call(msg, expected_type)
 
 
-def get_context() -> CodecContext:
+def get_context() -> Context:
     """Get the current session context.
 
     Can be needed in case the caller needs raw read and raw write capabilities, which
@@ -265,7 +265,7 @@ def get_context() -> CodecContext:
     return CURRENT_CONTEXT
 
 
-def with_context(ctx: CodecContext, workflow: loop.Task) -> Generator:
+def with_context(ctx: Context, workflow: loop.Task) -> Generator:
     """Run a workflow in a particular context.
 
     Stores the context in a closure and installs it into the global variable every time
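Per its docstring, with_context stores the context in a closure and reinstalls it into the global CURRENT_CONTEXT every time the wrapped workflow resumes. A stripped-down sketch of that save-and-install pattern using plain generators, independent of the trezor loop machinery:

from typing import Any, Generator

CURRENT_CONTEXT: object | None = None

def with_context(ctx: object, workflow: Generator[Any, Any, Any]) -> Generator[Any, Any, Any]:
    # Install `ctx` as the global context before every step of `workflow`,
    # and clear it again while control is back with the scheduler.
    global CURRENT_CONTEXT
    send_val: Any = None
    while True:
        CURRENT_CONTEXT = ctx
        try:
            yielded = workflow.send(send_val)
        except StopIteration as stop:
            return stop.value
        finally:
            CURRENT_CONTEXT = None
        send_val = yield yielded

def demo_workflow():
    # Code resumed here runs with CURRENT_CONTEXT pointing at its own context.
    got = yield "need-input"
    return got

wrapped = with_context("ctx-A", demo_workflow())
print(next(wrapped))  # -> "need-input", stepped with CURRENT_CONTEXT installed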

@@ -63,8 +63,8 @@ if __debug__:
 
 async def handle_single_message(
-    ctx: context.CodecContext, msg: protocol_common.MessageWithId, use_workflow: bool
-) -> protocol_common.MessageWithId | None:
+    ctx: context.Context, msg: protocol_common.MessageWithType, use_workflow: bool
+) -> protocol_common.MessageWithType | None:
     """Handle a message that was loaded from USB by the caller.
 
     Find the appropriate handler, run it and write its result on the wire. In case

@@ -4,7 +4,13 @@ from trezor import protobuf
 if TYPE_CHECKING:
     from trezorio import WireInterface  # pyright: ignore[reportMissingImports]
-    from typing import Container  # pyright: ignore[reportShadowedImports]
+    from typing import (  # pyright: ignore[reportShadowedImports]
+        Container,
+        TypeVar,
+        overload,
+    )
+
+    LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
 
 
 class Message:
@@ -47,6 +53,18 @@ class Context:
         self.iface: WireInterface = iface
         self.channel_id: bytes = channel_id
 
+    if TYPE_CHECKING:
+
+        @overload
+        async def read(
+            self, expected_types: Container[int]
+        ) -> protobuf.MessageType: ...
+
+        @overload
+        async def read(
+            self, expected_types: Container[int], expected_type: type[LoadedMessageType]
+        ) -> LoadedMessageType: ...
+
     async def read(
         self,
         expected_types: Container[int],
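The overloads are defined only under TYPE_CHECKING, so they cost nothing at runtime on the device but let a type checker narrow Context.read's return type when a concrete expected_type is passed. A minimal standalone illustration of the same trick (the Success class and wire code below are placeholders, not the firmware's real definitions):

from typing import TYPE_CHECKING, Container, TypeVar, overload

class MessageType:
    """Placeholder standing in for protobuf.MessageType."""

class Success(MessageType):
    MESSAGE_WIRE_TYPE = 2  # made-up wire code, for illustration only

T = TypeVar("T", bound=MessageType)

class Context:
    if TYPE_CHECKING:
        # Seen only by the type checker; never executed on the device.
        @overload
        async def read(self, expected_types: Container[int]) -> MessageType: ...
        @overload
        async def read(
            self, expected_types: Container[int], expected_type: type[T]
        ) -> T: ...

    async def read(self, expected_types, expected_type=None):
        # Runtime body: the overloads above only refine what a checker reports,
        # e.g. `reply = await ctx.read((Success.MESSAGE_WIRE_TYPE,), Success)`
        # is typed as Success rather than the MessageType base.
        return Success()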

@@ -5,10 +5,11 @@ from ubinascii import hexlify
 import usb
 from storage import cache_thp
-from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache
-from trezor import loop, protobuf, utils
+from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
+from trezor import io, loop, protobuf, utils
 from trezor.messages import ThpCreateNewSession
 from trezor.wire import message_handler
+from trezor.wire.thp import thp_messages
 
 from ..protocol_common import Context, MessageWithType
 from . import ChannelState, SessionState, checksum
@@ -19,6 +20,7 @@ from .thp_messages import (
     CONTINUATION_PACKET,
     ENCRYPTED_TRANSPORT,
     HANDSHAKE_INIT,
+    InitHeader,
 )
 from .thp_session import ThpError
@@ -34,6 +36,7 @@ _PUBKEY_LENGTH = const(32)
 INIT_DATA_OFFSET = const(5)
 CONT_DATA_OFFSET = const(3)
+MESSAGE_TYPE_LENGTH = const(2)
 
 REPORT_LENGTH = const(64)
 MAX_PAYLOAD_LEN = const(60000)
@@ -45,7 +48,7 @@ class Channel(Context):
         super().__init__(iface, channel_cache.channel_id)
         self.channel_cache = channel_cache
         self.buffer: utils.BufferType
-        self.waiting_for_ack_timeout: loop.Task | None
+        self.waiting_for_ack_timeout: loop.spawn | None
         self.is_cont_packet_expected: bool = False
         self.expected_payload_length: int = 0
         self.bytes_read = 0
@@ -175,6 +178,9 @@ class Channel(Context):
             )
             # TODO send ack in response
             # TODO send handshake init response message
+            await self._write_encrypted_payload_loop(
+                thp_messages.get_handshake_init_response()
+            )
             self.set_channel_state(ChannelState.TH2)
             return
@@ -196,7 +202,7 @@ class Channel(Context):
         expected_type = protobuf.type_for_wire(message_type)
         message = message_handler.wrap_protobuf_load(buf, expected_type)
         print(message)
-        # ------------------------------------------------TYPE ERROR------------------------------------------------
+        # TODO handle other messages than CreateNewSession
         assert isinstance(message, ThpCreateNewSession)
         print("passphrase:", message.passphrase)
         # await thp_messages.handle_CreateNewSession(message)
@@ -262,10 +268,84 @@ class Channel(Context):
     # CALLED BY WORKFLOW / SESSION CONTEXT
 
     async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
-        pass
-        # TODO protocol.write(self.iface, self.channel_id, session_id, msg)
-
-    # OTHER
+        noise_payload_len = self._encode_into_buffer(msg, session_id)
+        # trezor.crypto.noise.encode(key, payload=self.buffer)
+        # TODO payload_len should be output from trezor.crypto.noise.encode
+        payload_len = noise_payload_len  # + TAG_LENGTH # TODO
+        await self._write_encrypted_payload_loop(self.buffer[:payload_len])
+
+    async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
+        payload_len = len(payload)
+        header = InitHeader(
+            ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len
+        )
+        while True:
+            print("write encrypted payload loop - start")
+            await self._write_encrypted_payload(header, payload, payload_len)
+            self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
+            try:
+                await self.waiting_for_ack_timeout
+            except loop.TaskClosed:
+                break
+
+    async def _write_encrypted_payload(
+        self, header: InitHeader, payload: bytes, payload_len: int
+    ):
+        # prepare the report buffer with header data
+        report = bytearray(REPORT_LENGTH)
+        header.pack_to_buffer(report)
+
+        # write initial report
+        nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
+        await self._write_report(report)
+
+        # if we have more data to write, use continuation reports for it
+        if nwritten < payload_len:
+            header.pack_to_cont_buffer(report)
+
+        while nwritten < payload_len:
+            nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
+            await self._write_report(report)
+
+    async def _write_report(self, report: utils.BufferType) -> None:
+        while True:
+            await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
+            n = self.iface.write(report)
+            if n == len(report):
+                return
+
+    async def _wait_for_ack(self) -> None:
+        await loop.sleep(1000)
+        # TODO retry write
+
+    def _encode_into_buffer(self, msg: protobuf.MessageType, session_id: int) -> int:
+
+        # cannot write message without wire type
+        assert msg.MESSAGE_WIRE_TYPE is not None
+
+        msg_size = protobuf.encoded_length(msg)
+        offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
+        payload_size = offset + msg_size
+
+        if payload_size > len(self.buffer) or not isinstance(self.buffer, bytearray):
+            # message is too big or buffer is not bytearray, we need to allocate a new buffer
+            self.buffer = bytearray(payload_size)
+
+        buffer = self.buffer
+
+        session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
+        msg_type_bytes = int.to_bytes(msg.MESSAGE_WIRE_TYPE, MESSAGE_TYPE_LENGTH, "big")
+
+        utils.memcpy(buffer, 0, session_id_bytes, 0)
+        utils.memcpy(buffer, SESSION_ID_LENGTH, msg_type_bytes, 0)
+        assert isinstance(buffer, bytearray)
+        msg_size = protobuf.encode(buffer[offset:], msg)
+
+        return payload_size
 
     def create_new_session(
         self,
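_encode_into_buffer lays the message out as a 1-byte session id, a 2-byte big-endian message type, and then the protobuf-encoded body; _write_encrypted_payload then slices that payload into one 64-byte initial report (5-byte header, so 59 data bytes) followed by continuation reports (3-byte header, 61 data bytes each). A back-of-the-envelope sketch of that framing, using only the constants from this diff (reports_needed is an illustrative helper, not firmware code):

# Framing constants as defined in the hunks above.
REPORT_LENGTH = 64
INIT_DATA_OFFSET = 5   # ctrl byte (1) + channel id (2) + length (2)
CONT_DATA_OFFSET = 3   # ctrl byte (1) + channel id (2)
SESSION_ID_LENGTH = 1
MESSAGE_TYPE_LENGTH = 2

def reports_needed(encoded_protobuf_len: int) -> int:
    # Total noise payload: session id + message type + protobuf body.
    payload_len = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + encoded_protobuf_len
    first = REPORT_LENGTH - INIT_DATA_OFFSET  # 59 bytes fit in the initial report
    rest = REPORT_LENGTH - CONT_DATA_OFFSET   # 61 bytes per continuation report
    if payload_len <= first:
        return 1
    return 1 + -(-(payload_len - first) // rest)  # ceiling division

# e.g. a 200-byte protobuf message -> 203-byte payload -> 1 + ceil(144 / 61) = 4 reports
print(reports_needed(200))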

@@ -2,8 +2,8 @@ from typing import TYPE_CHECKING  # pyright: ignore[reportShadowedImports]
 from storage import cache_thp
 from storage.cache_thp import SessionThpCache
-from trezor import loop, protobuf
-from trezor.wire import message_handler
+from trezor import log, loop, protobuf
+from trezor.wire import AVOID_RESTARTING_FOR, failure, message_handler, protocol_common
 
 from ..protocol_common import Context, MessageWithType
 from . import SessionState
@@ -44,12 +44,66 @@ class SessionContext(Context):
         session_cache = cache_thp.get_new_session(channel_context.channel_cache)
         return cls(channel_context, session_cache)
 
-    async def handle(self) -> None:
+    async def handle(self, is_debug_session: bool = False) -> None:
+        if __debug__ and is_debug_session:
+            import apps.debug
+
+            apps.debug.DEBUG_CONTEXT = self
+
         take = self.incoming_message.take()
+        next_message: MessageWithType | None = None
+
+        # Take a mark of modules that are imported at this point, so we can
+        # roll back and un-import any others.
+        # TODO modules = utils.unimport_begin()
         while True:
-            message = await take
-            print(message)
-            # TODO continue similarly to handle_session function in wire.__init__
+            try:
+                if next_message is None:
+                    # If the previous run did not keep an unprocessed message for us,
+                    # wait for a new one.
+                    try:
+                        message: MessageWithType = await take
+                    except protocol_common.WireError as e:
+                        if __debug__:
+                            log.exception(__name__, e)
+                        await self.write(failure(e))
+                        continue
+                else:
+                    # Process the message from previous run.
+                    message = next_message
+                    next_message = None
+
+                try:
+                    next_message = await message_handler.handle_single_message(
+                        self, message, use_workflow=not is_debug_session
+                    )
+                except Exception as exc:
+                    # Log and ignore. The session handler can only exit explicitly in the
+                    # following finally block.
+                    if __debug__:
+                        log.exception(__name__, exc)
+                finally:
+                    if not __debug__ or not is_debug_session:
+                        # Unload modules imported by the workflow. Should not raise.
+                        # This is not done for the debug session because the snapshot taken
+                        # in a debug session would clear modules which are in use by the
+                        # workflow running on wire.
+                        # TODO utils.unimport_end(modules)
+
+                        if (
+                            next_message is None
+                            and message.type not in AVOID_RESTARTING_FOR
+                        ):
+                            # Shut down the loop if there is no next message waiting.
+                            # Let the session be restarted from `main`.
+                            loop.clear()
+                            return  # pylint: disable=lost-exception
+
+            except Exception as exc:
+                # Log and try again. The session handler can only exit explicitly via
+                # loop.clear() above.
+                if __debug__:
+                    log.exception(__name__, exc)
 
     async def read(
         self,
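The shutdown return sits inside a finally block, which is why the pylint lost-exception suppression is there: a return in finally discards any exception still in flight from the try body, so the handler can clear the loop cleanly even if the last workflow failed. A tiny standalone demonstration of that behaviour:

def run() -> str:
    try:
        raise RuntimeError("workflow failed")
    finally:
        # Returning from `finally` swallows the in-flight RuntimeError,
        # which is what `return  # pylint: disable=lost-exception` relies on.
        return "loop cleared"

print(run())  # prints "loop cleared"; no exception propagates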

@@ -19,7 +19,7 @@ _CHANNEL_ALLOCATION_RES = 0x40
 class InitHeader:
     format_str = ">BHH"
 
-    def __init__(self, ctrl_byte, cid, length) -> None:
+    def __init__(self, ctrl_byte, cid: int, length: int) -> None:
         self.ctrl_byte = ctrl_byte
         self.cid = cid
         self.length = length
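InitHeader's ">BHH" format packs a 1-byte control byte, a 2-byte channel id and a 2-byte length, big-endian, which is exactly the 5 bytes that INIT_DATA_OFFSET reserves at the start of the initial report. Checked here with CPython's struct as a stand-in for the firmware's ustruct:

import struct

# ctrl byte, channel id, payload length, big-endian -- mirrors InitHeader.format_str.
# 0x04 is just an example control byte, not necessarily ENCRYPTED_TRANSPORT's real value.
header = struct.pack(">BHH", 0x04, 0x1234, 203)
print(len(header), header.hex())  # -> 5 04123400cb  (5 bytes == INIT_DATA_OFFSET)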
@@ -79,7 +79,7 @@ def get_error_unallocated_channel() -> bytes:
 
 def get_handshake_init_response() -> bytes:
-    return b"\x00"  # TODO implement
+    return b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"  # TODO implement
 
 
 def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType:

@@ -222,7 +222,7 @@ class TestStorageCache(unittest.TestCase):
         self.assertEqual(cache.get(KEY), b"hello")
 
         # supplying a different session ID starts a new cache
-        call_Initialize(session_id=b"A" * _PROTOCOL_CACHE._SESSION_ID_LENGTH)
+        call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
         self.assertIsNone(cache.get(KEY))
 
         # but resuming a session loads the previous one
