mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
feat(core): implement SessionContext write
This commit is contained in:
parent
6ab6d2c109
commit
555708493b
@ -20,7 +20,7 @@ _CHANNEL_STATE_LENGTH = const(1)
|
||||
_WIRE_INTERFACE_LENGTH = const(1)
|
||||
_SESSION_STATE_LENGTH = const(1)
|
||||
_CHANNEL_ID_LENGTH = const(2)
|
||||
_SESSION_ID_LENGTH = const(1)
|
||||
SESSION_ID_LENGTH = const(1)
|
||||
BROADCAST_CHANNEL_ID = const(65535)
|
||||
KEY_LENGTH = const(32)
|
||||
TAG_LENGTH = const(16)
|
||||
@ -61,7 +61,7 @@ class ChannelCache(ConnectionCache):
|
||||
|
||||
class SessionThpCache(ConnectionCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
self.session_id = bytearray(SESSION_ID_LENGTH)
|
||||
self.state = bytearray(_SESSION_STATE_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
@ -284,7 +284,7 @@ def get_next_session_id(channel: ChannelCache) -> bytes:
|
||||
if _is_session_id_unique(channel):
|
||||
break
|
||||
new_sid = channel.session_id_counter
|
||||
return new_sid.to_bytes(_SESSION_ID_LENGTH, "big")
|
||||
return new_sid.to_bytes(SESSION_ID_LENGTH, "big")
|
||||
|
||||
|
||||
def _is_session_id_unique(channel: ChannelCache) -> bool:
|
||||
@ -307,10 +307,8 @@ def _get_cid(session: SessionThpCache) -> int:
|
||||
|
||||
|
||||
def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache:
|
||||
if len(session_id) != _SESSION_ID_LENGTH:
|
||||
raise ValueError(
|
||||
"session_id must be X bytes long, where X=", _SESSION_ID_LENGTH
|
||||
)
|
||||
if len(session_id) != SESSION_ID_LENGTH:
|
||||
raise ValueError("session_id must be X bytes long, where X=", SESSION_ID_LENGTH)
|
||||
global _active_session_idx
|
||||
global _is_active_session_authenticated
|
||||
global _next_unauthenicated_session_index
|
||||
|
@ -177,9 +177,15 @@ async def handle_session(
|
||||
ctx.channel_id = msg.session_id
|
||||
|
||||
try:
|
||||
next_msg = await message_handler.handle_single_message(
|
||||
next_msg_without_id = await message_handler.handle_single_message(
|
||||
ctx, msg, use_workflow=not is_debug_session
|
||||
)
|
||||
if next_msg_without_id is not None:
|
||||
next_msg = protocol_common.MessageWithId(
|
||||
next_msg_without_id.type,
|
||||
next_msg_without_id.data,
|
||||
bytearray(ctx.channel_id),
|
||||
)
|
||||
except Exception as exc:
|
||||
# Log and ignore. The session handler can only exit explicitly in the
|
||||
# following finally block.
|
||||
|
@ -186,7 +186,7 @@ class CodecContext(Context):
|
||||
)
|
||||
|
||||
|
||||
CURRENT_CONTEXT: CodecContext | None = None
|
||||
CURRENT_CONTEXT: Context | None = None
|
||||
|
||||
|
||||
def wait(task: Awaitable[T]) -> Awaitable[T]:
|
||||
@ -251,7 +251,7 @@ async def maybe_call(
|
||||
await call(msg, expected_type)
|
||||
|
||||
|
||||
def get_context() -> CodecContext:
|
||||
def get_context() -> Context:
|
||||
"""Get the current session context.
|
||||
|
||||
Can be needed in case the caller needs raw read and raw write capabilities, which
|
||||
@ -265,7 +265,7 @@ def get_context() -> CodecContext:
|
||||
return CURRENT_CONTEXT
|
||||
|
||||
|
||||
def with_context(ctx: CodecContext, workflow: loop.Task) -> Generator:
|
||||
def with_context(ctx: Context, workflow: loop.Task) -> Generator:
|
||||
"""Run a workflow in a particular context.
|
||||
|
||||
Stores the context in a closure and installs it into the global variable every time
|
||||
|
@ -63,8 +63,8 @@ if __debug__:
|
||||
|
||||
|
||||
async def handle_single_message(
|
||||
ctx: context.CodecContext, msg: protocol_common.MessageWithId, use_workflow: bool
|
||||
) -> protocol_common.MessageWithId | None:
|
||||
ctx: context.Context, msg: protocol_common.MessageWithType, use_workflow: bool
|
||||
) -> protocol_common.MessageWithType | None:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
Find the appropriate handler, run it and write its result on the wire. In case
|
||||
|
@ -4,8 +4,13 @@ from trezor import protobuf
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from trezorio import WireInterface
|
||||
from typing import Container
|
||||
from typing import (
|
||||
Container,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
class Message:
|
||||
@ -48,6 +53,18 @@ class Context:
|
||||
self.iface: WireInterface = iface
|
||||
self.channel_id: bytes = channel_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int]
|
||||
) -> protobuf.MessageType: ...
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||
) -> LoadedMessageType: ...
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
|
@ -5,10 +5,11 @@ from ubinascii import hexlify
|
||||
|
||||
import usb
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache
|
||||
from trezor import loop, protobuf, utils
|
||||
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
|
||||
from trezor import io, loop, protobuf, utils
|
||||
from trezor.messages import ThpCreateNewSession
|
||||
from trezor.wire import message_handler
|
||||
from trezor.wire.thp import thp_messages
|
||||
|
||||
from ..protocol_common import Context, MessageWithType
|
||||
from . import ChannelState, SessionState, checksum
|
||||
@ -19,6 +20,7 @@ from .thp_messages import (
|
||||
CONTINUATION_PACKET,
|
||||
ENCRYPTED_TRANSPORT,
|
||||
HANDSHAKE_INIT,
|
||||
InitHeader,
|
||||
)
|
||||
from .thp_session import ThpError
|
||||
|
||||
@ -34,6 +36,7 @@ _PUBKEY_LENGTH = const(32)
|
||||
INIT_DATA_OFFSET = const(5)
|
||||
CONT_DATA_OFFSET = const(3)
|
||||
|
||||
MESSAGE_TYPE_LENGTH = const(2)
|
||||
|
||||
REPORT_LENGTH = const(64)
|
||||
MAX_PAYLOAD_LEN = const(60000)
|
||||
@ -45,7 +48,7 @@ class Channel(Context):
|
||||
super().__init__(iface, channel_cache.channel_id)
|
||||
self.channel_cache = channel_cache
|
||||
self.buffer: utils.BufferType
|
||||
self.waiting_for_ack_timeout: loop.Task | None
|
||||
self.waiting_for_ack_timeout: loop.spawn | None
|
||||
self.is_cont_packet_expected: bool = False
|
||||
self.expected_payload_length: int = 0
|
||||
self.bytes_read = 0
|
||||
@ -175,6 +178,9 @@ class Channel(Context):
|
||||
)
|
||||
# TODO send ack in response
|
||||
# TODO send handshake init response message
|
||||
await self._write_encrypted_payload_loop(
|
||||
thp_messages.get_handshake_init_response()
|
||||
)
|
||||
self.set_channel_state(ChannelState.TH2)
|
||||
return
|
||||
|
||||
@ -196,7 +202,7 @@ class Channel(Context):
|
||||
expected_type = protobuf.type_for_wire(message_type)
|
||||
message = message_handler.wrap_protobuf_load(buf, expected_type)
|
||||
print(message)
|
||||
# ------------------------------------------------TYPE ERROR------------------------------------------------
|
||||
# TODO handle other messages than CreateNewSession
|
||||
assert isinstance(message, ThpCreateNewSession)
|
||||
print("passphrase:", message.passphrase)
|
||||
# await thp_messages.handle_CreateNewSession(message)
|
||||
@ -262,10 +268,84 @@ class Channel(Context):
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
pass
|
||||
# TODO protocol.write(self.iface, self.channel_id, session_id, msg)
|
||||
|
||||
# OTHER
|
||||
noise_payload_len = self._encode_into_buffer(msg, session_id)
|
||||
|
||||
# trezor.crypto.noise.encode(key, payload=self.buffer)
|
||||
|
||||
# TODO payload_len should be output from trezor.crypto.noise.encode
|
||||
payload_len = noise_payload_len # + TAG_LENGTH # TODO
|
||||
|
||||
await self._write_encrypted_payload_loop(self.buffer[:payload_len])
|
||||
|
||||
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
||||
|
||||
payload_len = len(payload)
|
||||
header = InitHeader(
|
||||
ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len
|
||||
)
|
||||
|
||||
while True:
|
||||
print("write encrypted payload loop - start")
|
||||
await self._write_encrypted_payload(header, payload, payload_len)
|
||||
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
|
||||
try:
|
||||
await self.waiting_for_ack_timeout
|
||||
except loop.TaskClosed:
|
||||
break
|
||||
|
||||
async def _write_encrypted_payload(
|
||||
self, header: InitHeader, payload: bytes, payload_len: int
|
||||
):
|
||||
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
|
||||
await self._write_report(report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < payload_len:
|
||||
header.pack_to_cont_buffer(report)
|
||||
while nwritten < payload_len:
|
||||
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
|
||||
await self._write_report(report)
|
||||
|
||||
async def _write_report(self, report: utils.BufferType) -> None:
|
||||
while True:
|
||||
await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
|
||||
n = self.iface.write(report)
|
||||
if n == len(report):
|
||||
return
|
||||
|
||||
async def _wait_for_ack(self) -> None:
|
||||
await loop.sleep(1000)
|
||||
# TODO retry write
|
||||
|
||||
def _encode_into_buffer(self, msg: protobuf.MessageType, session_id: int) -> int:
|
||||
|
||||
# cannot write message without wire type
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
|
||||
payload_size = offset + msg_size
|
||||
|
||||
if payload_size > len(self.buffer) or not isinstance(self.buffer, bytearray):
|
||||
# message is too big or buffer is not bytearray, we need to allocate a new buffer
|
||||
self.buffer = bytearray(payload_size)
|
||||
|
||||
buffer = self.buffer
|
||||
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
|
||||
msg_type_bytes = int.to_bytes(msg.MESSAGE_WIRE_TYPE, MESSAGE_TYPE_LENGTH, "big")
|
||||
|
||||
utils.memcpy(buffer, 0, session_id_bytes, 0)
|
||||
utils.memcpy(buffer, SESSION_ID_LENGTH, msg_type_bytes, 0)
|
||||
assert isinstance(buffer, bytearray)
|
||||
msg_size = protobuf.encode(buffer[offset:], msg)
|
||||
return payload_size
|
||||
|
||||
def create_new_session(
|
||||
self,
|
||||
|
@ -2,8 +2,8 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import SessionThpCache
|
||||
from trezor import loop, protobuf
|
||||
from trezor.wire import message_handler
|
||||
from trezor import log, loop, protobuf
|
||||
from trezor.wire import AVOID_RESTARTING_FOR, failure, message_handler, protocol_common
|
||||
|
||||
from ..protocol_common import Context, MessageWithType
|
||||
from . import SessionState
|
||||
@ -44,12 +44,66 @@ class SessionContext(Context):
|
||||
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
|
||||
return cls(channel_context, session_cache)
|
||||
|
||||
async def handle(self) -> None:
|
||||
async def handle(self, is_debug_session: bool = False) -> None:
|
||||
if __debug__ and is_debug_session:
|
||||
import apps.debug
|
||||
|
||||
apps.debug.DEBUG_CONTEXT = self
|
||||
|
||||
take = self.incoming_message.take()
|
||||
next_message: MessageWithType | None = None
|
||||
|
||||
# Take a mark of modules that are imported at this point, so we can
|
||||
# roll back and un-import any others.
|
||||
# TODO modules = utils.unimport_begin()
|
||||
while True:
|
||||
message = await take
|
||||
print(message)
|
||||
# TODO continue similarly to handle_session function in wire.__init__
|
||||
try:
|
||||
if next_message is None:
|
||||
# If the previous run did not keep an unprocessed message for us,
|
||||
# wait for a new one.
|
||||
try:
|
||||
message: MessageWithType = await take
|
||||
except protocol_common.WireError as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
await self.write(failure(e))
|
||||
continue
|
||||
else:
|
||||
# Process the message from previous run.
|
||||
message = next_message
|
||||
next_message = None
|
||||
|
||||
try:
|
||||
next_message = await message_handler.handle_single_message(
|
||||
self, message, use_workflow=not is_debug_session
|
||||
)
|
||||
except Exception as exc:
|
||||
# Log and ignore. The session handler can only exit explicitly in the
|
||||
# following finally block.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
finally:
|
||||
if not __debug__ or not is_debug_session:
|
||||
# Unload modules imported by the workflow. Should not raise.
|
||||
# This is not done for the debug session because the snapshot taken
|
||||
# in a debug session would clear modules which are in use by the
|
||||
# workflow running on wire.
|
||||
# TODO utils.unimport_end(modules)
|
||||
|
||||
if (
|
||||
next_message is None
|
||||
and message.type not in AVOID_RESTARTING_FOR
|
||||
):
|
||||
# Shut down the loop if there is no next message waiting.
|
||||
# Let the session be restarted from `main`.
|
||||
loop.clear()
|
||||
return # pylint: disable=lost-exception
|
||||
|
||||
except Exception as exc:
|
||||
# Log and try again. The session handler can only exit explicitly via
|
||||
# loop.clear() above.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
|
@ -19,7 +19,7 @@ _CHANNEL_ALLOCATION_RES = 0x40
|
||||
class InitHeader:
|
||||
format_str = ">BHH"
|
||||
|
||||
def __init__(self, ctrl_byte, cid, length) -> None:
|
||||
def __init__(self, ctrl_byte, cid: int, length: int) -> None:
|
||||
self.ctrl_byte = ctrl_byte
|
||||
self.cid = cid
|
||||
self.length = length
|
||||
@ -79,7 +79,7 @@ def get_error_unallocated_channel() -> bytes:
|
||||
|
||||
|
||||
def get_handshake_init_response() -> bytes:
|
||||
return b"\x00" # TODO implement
|
||||
return b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" # TODO implement
|
||||
|
||||
|
||||
def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType:
|
||||
|
@ -222,7 +222,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
|
||||
# supplying a different session ID starts a new cache
|
||||
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE._SESSION_ID_LENGTH)
|
||||
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
|
||||
# but resuming a session loads the previous one
|
||||
|
Loading…
Reference in New Issue
Block a user