Implement SessionContext write

M1nd3r/thp5
M1nd3r 1 month ago
parent 4a76216cbf
commit 117f72c689

@ -20,7 +20,7 @@ _CHANNEL_STATE_LENGTH = const(1)
_WIRE_INTERFACE_LENGTH = const(1)
_SESSION_STATE_LENGTH = const(1)
_CHANNEL_ID_LENGTH = const(2)
_SESSION_ID_LENGTH = const(1)
SESSION_ID_LENGTH = const(1)
BROADCAST_CHANNEL_ID = const(65535)
KEY_LENGTH = const(32)
TAG_LENGTH = const(16)
@ -61,7 +61,7 @@ class ChannelCache(ConnectionCache):
class SessionThpCache(ConnectionCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
self.session_id = bytearray(SESSION_ID_LENGTH)
self.state = bytearray(_SESSION_STATE_LENGTH)
if utils.BITCOIN_ONLY:
self.fields = (
@ -284,7 +284,7 @@ def get_next_session_id(channel: ChannelCache) -> bytes:
if _is_session_id_unique(channel):
break
new_sid = channel.session_id_counter
return new_sid.to_bytes(_SESSION_ID_LENGTH, "big")
return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
def _is_session_id_unique(channel: ChannelCache) -> bool:
@ -307,10 +307,8 @@ def _get_cid(session: SessionThpCache) -> int:
def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache:
if len(session_id) != _SESSION_ID_LENGTH:
raise ValueError(
"session_id must be X bytes long, where X=", _SESSION_ID_LENGTH
)
if len(session_id) != SESSION_ID_LENGTH:
raise ValueError("session_id must be X bytes long, where X=", SESSION_ID_LENGTH)
global _active_session_idx
global _is_active_session_authenticated
global _next_unauthenicated_session_index

@ -177,9 +177,15 @@ async def handle_session(
ctx.channel_id = msg.session_id
try:
next_msg = await message_handler.handle_single_message(
next_msg_without_id = await message_handler.handle_single_message(
ctx, msg, use_workflow=not is_debug_session
)
if next_msg_without_id is not None:
next_msg = protocol_common.MessageWithId(
next_msg_without_id.type,
next_msg_without_id.data,
bytearray(ctx.channel_id),
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.

@ -186,7 +186,7 @@ class CodecContext(Context):
)
CURRENT_CONTEXT: CodecContext | None = None
CURRENT_CONTEXT: Context | None = None
def wait(task: Awaitable[T]) -> Awaitable[T]:
@ -251,7 +251,7 @@ async def maybe_call(
await call(msg, expected_type)
def get_context() -> CodecContext:
def get_context() -> Context:
"""Get the current session context.
Can be needed in case the caller needs raw read and raw write capabilities, which
@ -265,7 +265,7 @@ def get_context() -> CodecContext:
return CURRENT_CONTEXT
def with_context(ctx: CodecContext, workflow: loop.Task) -> Generator:
def with_context(ctx: Context, workflow: loop.Task) -> Generator:
"""Run a workflow in a particular context.
Stores the context in a closure and installs it into the global variable every time

@ -63,8 +63,8 @@ if __debug__:
async def handle_single_message(
ctx: context.CodecContext, msg: protocol_common.MessageWithId, use_workflow: bool
) -> protocol_common.MessageWithId | None:
ctx: context.Context, msg: protocol_common.MessageWithType, use_workflow: bool
) -> protocol_common.MessageWithType | None:
"""Handle a message that was loaded from USB by the caller.
Find the appropriate handler, run it and write its result on the wire. In case

@ -4,7 +4,13 @@ from trezor import protobuf
if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
from typing import Container # pyright: ignore[reportShadowedImports]
from typing import ( # pyright: ignore[reportShadowedImports]
Container,
TypeVar,
overload,
)
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
class Message:
@ -47,6 +53,18 @@ class Context:
self.iface: WireInterface = iface
self.channel_id: bytes = channel_id
if TYPE_CHECKING:
@overload
async def read(
self, expected_types: Container[int]
) -> protobuf.MessageType: ...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType: ...
async def read(
self,
expected_types: Container[int],

@ -5,10 +5,11 @@ from ubinascii import hexlify
import usb
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache
from trezor import loop, protobuf, utils
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from trezor import io, loop, protobuf, utils
from trezor.messages import ThpCreateNewSession
from trezor.wire import message_handler
from trezor.wire.thp import thp_messages
from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum
@ -19,6 +20,7 @@ from .thp_messages import (
CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT,
HANDSHAKE_INIT,
InitHeader,
)
from .thp_session import ThpError
@ -34,6 +36,7 @@ _PUBKEY_LENGTH = const(32)
INIT_DATA_OFFSET = const(5)
CONT_DATA_OFFSET = const(3)
MESSAGE_TYPE_LENGTH = const(2)
REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000)
@ -45,7 +48,7 @@ class Channel(Context):
super().__init__(iface, channel_cache.channel_id)
self.channel_cache = channel_cache
self.buffer: utils.BufferType
self.waiting_for_ack_timeout: loop.Task | None
self.waiting_for_ack_timeout: loop.spawn | None
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read = 0
@ -175,6 +178,9 @@ class Channel(Context):
)
# TODO send ack in response
# TODO send handshake init response message
await self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
)
self.set_channel_state(ChannelState.TH2)
return
@ -196,7 +202,7 @@ class Channel(Context):
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
print(message)
# ------------------------------------------------TYPE ERROR------------------------------------------------
# TODO handle other messages than CreateNewSession
assert isinstance(message, ThpCreateNewSession)
print("passphrase:", message.passphrase)
# await thp_messages.handle_CreateNewSession(message)
@ -262,10 +268,84 @@ class Channel(Context):
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
pass
# TODO protocol.write(self.iface, self.channel_id, session_id, msg)
# OTHER
noise_payload_len = self._encode_into_buffer(msg, session_id)
# trezor.crypto.noise.encode(key, payload=self.buffer)
# TODO payload_len should be output from trezor.crypto.noise.encode
payload_len = noise_payload_len # + TAG_LENGTH # TODO
await self._write_encrypted_payload_loop(self.buffer[:payload_len])
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
payload_len = len(payload)
header = InitHeader(
ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len
)
while True:
print("write encrypted payload loop - start")
await self._write_encrypted_payload(header, payload, payload_len)
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
try:
await self.waiting_for_ack_timeout
except loop.TaskClosed:
break
async def _write_encrypted_payload(
self, header: InitHeader, payload: bytes, payload_len: int
):
# prepare the report buffer with header data
report = bytearray(REPORT_LENGTH)
header.pack_to_buffer(report)
# write initial report
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
await self._write_report(report)
# if we have more data to write, use continuation reports for it
if nwritten < payload_len:
header.pack_to_cont_buffer(report)
while nwritten < payload_len:
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
await self._write_report(report)
async def _write_report(self, report: utils.BufferType) -> None:
while True:
await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
n = self.iface.write(report)
if n == len(report):
return
async def _wait_for_ack(self) -> None:
await loop.sleep(1000)
# TODO retry write
def _encode_into_buffer(self, msg: protobuf.MessageType, session_id: int) -> int:
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
payload_size = offset + msg_size
if payload_size > len(self.buffer) or not isinstance(self.buffer, bytearray):
# message is too big or buffer is not bytearray, we need to allocate a new buffer
self.buffer = bytearray(payload_size)
buffer = self.buffer
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
msg_type_bytes = int.to_bytes(msg.MESSAGE_WIRE_TYPE, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, 0, session_id_bytes, 0)
utils.memcpy(buffer, SESSION_ID_LENGTH, msg_type_bytes, 0)
assert isinstance(buffer, bytearray)
msg_size = protobuf.encode(buffer[offset:], msg)
return payload_size
def create_new_session(
self,

@ -2,8 +2,8 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage import cache_thp
from storage.cache_thp import SessionThpCache
from trezor import loop, protobuf
from trezor.wire import message_handler
from trezor import log, loop, protobuf
from trezor.wire import AVOID_RESTARTING_FOR, failure, message_handler, protocol_common
from ..protocol_common import Context, MessageWithType
from . import SessionState
@ -44,12 +44,66 @@ class SessionContext(Context):
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
return cls(channel_context, session_cache)
async def handle(self) -> None:
async def handle(self, is_debug_session: bool = False) -> None:
if __debug__ and is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
take = self.incoming_message.take()
next_message: MessageWithType | None = None
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
# TODO modules = utils.unimport_begin()
while True:
message = await take
print(message)
# TODO continue similarly to handle_session function in wire.__init__
try:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
try:
message: MessageWithType = await take
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(failure(e))
continue
else:
# Process the message from previous run.
message = next_message
next_message = None
try:
next_message = await message_handler.handle_single_message(
self, message, use_workflow=not is_debug_session
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if (
next_message is None
and message.type not in AVOID_RESTARTING_FOR
):
# Shut down the loop if there is no next message waiting.
# Let the session be restarted from `main`.
loop.clear()
return # pylint: disable=lost-exception
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
async def read(
self,

@ -19,7 +19,7 @@ _CHANNEL_ALLOCATION_RES = 0x40
class InitHeader:
format_str = ">BHH"
def __init__(self, ctrl_byte, cid, length) -> None:
def __init__(self, ctrl_byte, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.length = length
@ -79,7 +79,7 @@ def get_error_unallocated_channel() -> bytes:
def get_handshake_init_response() -> bytes:
return b"\x00" # TODO implement
return b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" # TODO implement
def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType:

@ -222,7 +222,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), b"hello")
# supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE._SESSION_ID_LENGTH)
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
self.assertIsNone(cache.get(KEY))
# but resuming a session loads the previous one

Loading…
Cancel
Save