1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

feat(core): improve channels, sessions and handshake

This commit is contained in:
M1nd3r 2024-04-27 02:21:30 +02:00
parent 42873b1c30
commit 7c447ac5d1
12 changed files with 495 additions and 126 deletions

View File

@ -19,10 +19,11 @@ _MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove
_CHANNEL_STATE_LENGTH = const(1)
_WIRE_INTERFACE_LENGTH = const(1)
_SESSION_STATE_LENGTH = const(1)
_CHANNEL_ID_LENGTH = const(4)
_SESSION_ID_LENGTH = const(4)
_CHANNEL_ID_LENGTH = const(2)
_SESSION_ID_LENGTH = const(1)
BROADCAST_CHANNEL_ID = const(65535)
KEY_LENGTH = const(32)
TAG_LENGTH = const(16)
_UNALLOCATED_STATE = const(0)
@ -40,10 +41,14 @@ class ConnectionCache(DataCache):
class ChannelCache(ConnectionCache):
def __init__(self) -> None:
self.enc_key = 0 # TODO change
self.dec_key = 1 # TODO change
self.host_ephemeral_pubkey = bytearray(KEY_LENGTH)
self.enc_key = bytearray(KEY_LENGTH)
self.dec_key = bytearray(KEY_LENGTH)
self.state = bytearray(_CHANNEL_STATE_LENGTH)
self.iface = bytearray(1) # TODO add decoding
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
self.session_id_counter = 0x01
self.fields = ()
super().__init__()
def clear(self) -> None:
@ -135,13 +140,20 @@ def get_new_unauthenticated_channel(iface: bytes) -> ChannelCache:
new_cid = get_next_channel_id()
index = _get_next_unauthenticated_channel_index()
# clear sessions from replaced channel
if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE:
old_cid = _CHANNELS[index].channel_id
for session in _SESSIONS:
if session.channel_id == old_cid:
session.clear()
_CHANNELS[index] = ChannelCache()
_CHANNELS[index].channel_id[:] = new_cid
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
_CHANNELS[index].state = bytearray(
_CHANNELS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
)
_CHANNELS[index].iface = bytearray(iface)
_CHANNELS[index].iface[:] = bytearray(iface)
return _CHANNELS[index]
@ -153,6 +165,38 @@ def get_all_allocated_channels() -> list[ChannelCache]:
return _list
def get_all_allocated_sessions() -> list[SessionThpCache]:
_list: list[SessionThpCache] = []
for session in _SESSIONS:
if _get_session_state(session) != _UNALLOCATED_STATE:
_list.append(session)
return _list
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
if len(key) != KEY_LENGTH:
raise Exception("Invalid key length")
channel.host_ephemeral_pubkey = key
def get_new_session(channel: ChannelCache):
new_sid = get_next_session_id(channel)
index = _get_next_session_index()
_SESSIONS[index] = SessionThpCache()
_SESSIONS[index].channel_id[:] = channel.channel_id
_SESSIONS[index].session_id[:] = new_sid
_SESSIONS[index].last_usage = _get_usage_counter_and_increment()
channel.last_usage = (
_get_usage_counter_and_increment()
) # increment also use of the channel so it does not get replaced
_SESSIONS[index].state[:] = bytearray(
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
)
return _SESSIONS[index]
def _get_usage_counter() -> int:
global _usage_counter
return _usage_counter
@ -171,6 +215,13 @@ def _get_next_unauthenticated_channel_index() -> int:
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
def _get_next_session_index() -> int:
idx = _get_unallocated_session_index()
if idx is not None:
return idx
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
def _get_unallocated_channel_index() -> int | None:
for i in range(_MAX_CHANNELS_COUNT):
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
@ -178,12 +229,25 @@ def _get_unallocated_channel_index() -> int | None:
return None
def _get_unallocated_session_index() -> int | None:
for i in range(_MAX_SESSIONS_COUNT):
if (_SESSIONS[i]) is _UNALLOCATED_STATE:
return i
return None
def _get_channel_state(channel: ChannelCache) -> int:
if channel is None:
return _UNALLOCATED_STATE
return int.from_bytes(channel.state, "big")
def _get_session_state(session: SessionThpCache) -> int:
if session is None:
return _UNALLOCATED_STATE
return int.from_bytes(session.state, "big")
def get_active_session_id() -> bytearray | None:
active_session = get_active_session()
@ -200,9 +264,6 @@ def get_active_session() -> SessionThpCache | None:
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
_session_usage_counter = 0
def get_next_channel_id() -> bytes:
global cid_counter
while True:
@ -214,6 +275,25 @@ def get_next_channel_id() -> bytes:
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
def get_next_session_id(channel: ChannelCache) -> bytes:
while not _is_session_id_unique(channel):
if channel.session_id_counter >= 255:
channel.session_id_counter = 1
else:
channel.session_id_counter += 1
new_sid = channel.session_id_counter
channel.session_id_counter += 1
return new_sid.to_bytes(_SESSION_ID_LENGTH, "big")
def _is_session_id_unique(channel: ChannelCache) -> bool:
for session in _SESSIONS:
if session.channel_id == channel.channel_id:
if session.session_id == channel.session_id_counter:
return False
return True
def _is_cid_unique() -> bool:
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
if cid_counter == _get_cid(session):
@ -226,8 +306,10 @@ def _get_cid(session: SessionThpCache) -> int:
def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache:
if len(session_id) != 4:
raise ValueError("session_id must be 4 bytes long.")
if len(session_id) != _SESSION_ID_LENGTH:
raise ValueError(
"session_id must be X bytes long, where X=", _SESSION_ID_LENGTH
)
global _active_session_idx
global _is_active_session_authenticated
global _next_unauthenicated_session_index
@ -302,7 +384,10 @@ def start_session(session_id: bytes | None) -> bytes: # TODO incomplete
_active_session_idx = index
_is_active_session_authenticated = False
return session_id
new_session_id = b"\x00\x00" + get_next_channel_id()
channel = get_new_unauthenticated_channel(b"\x00")
new_session_id = get_next_session_id(channel)
new_session = create_new_unauthenticated_session(new_session_id)

View File

@ -40,7 +40,13 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Any, Callable, Container, Coroutine, TypeVar
from typing import (
Any,
Callable,
Container,
Coroutine,
TypeVar,
)
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
@ -53,10 +59,11 @@ EXPERIMENTAL_ENABLED = False
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
"""Initialize the wire stack on passed USB interface."""
loop.schedule(
handle_session(iface, codec_v1.SESSION_ID.to_bytes(4, "big"), is_debug_session)
)
"""Initialize the wire stack on passed WireInterface."""
if utils.USE_THP:
loop.schedule(handle_thp_session(iface, is_debug_session))
else:
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
def wrap_protobuf_load(
@ -128,13 +135,13 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
async def handle_session(
iface: WireInterface, session_id: bytes, is_debug_session: bool = False
iface: WireInterface, codec_session_id: int, is_debug_session: bool = False
) -> None:
if __debug__ and is_debug_session:
ctx_buffer = WIRE_BUFFER_DEBUG
else:
ctx_buffer = WIRE_BUFFER
session_id = codec_session_id.to_bytes(4, "big")
ctx = context.CodecContext(iface, ctx_buffer, session_id)
next_msg: protocol_common.MessageWithId | None = None

View File

@ -15,7 +15,13 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Any, Callable, Container, Coroutine, TypeVar
from typing import (
Any,
Callable,
Container,
Coroutine,
TypeVar,
)
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]

View File

@ -41,13 +41,13 @@ class MessageWithId(MessageWithType):
super().__init__(message_type, message_data)
class WireError(Exception):
pass
class Context:
def __init__(self, iface: WireInterface, channel_id: bytes) -> None:
self.iface: WireInterface = iface
self.channel_id: bytes = channel_id
async def write(self, msg: protobuf.MessageType) -> None: ...
class WireError(Exception):
pass

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
if TYPE_CHECKING:
from enum import IntEnum
@ -19,6 +19,10 @@ class ChannelState(IntEnum):
ENCRYPTED_TRANSPORT = 9
class SessionState(IntEnum):
UNALLOCATED = 0
class WireInterfaceType(IntEnum):
MOCK = 0
USB = 1

View File

@ -4,26 +4,28 @@ from trezor import log
from . import thp_session as THP
def handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
def handle_received_ACK(cache: SessionThpCache, sync_bit: int) -> None:
if _ack_is_not_expected(session):
if __debug__:
log.debug(__name__, "Received unexpected ACK message")
if _ack_is_not_expected(cache):
_conditionally_log_debug(__name__, "Received unexpected ACK message")
return
if _ack_has_incorrect_sync_bit(session, sync_bit):
if __debug__:
log.debug(__name__, "Received ACK message with wrong sync bit")
if _ack_has_incorrect_sync_bit(cache, sync_bit):
_conditionally_log_debug(__name__, "Received ACK message with wrong sync bit")
return
# ACK is expected and it has correct sync bit
_conditionally_log_debug(__name__, "Received ACK message with correct sync bit")
THP.sync_set_can_send_message(cache, True)
def _ack_is_not_expected(cache: SessionThpCache) -> bool:
return THP.sync_can_send_message(cache)
def _ack_has_incorrect_sync_bit(cache: SessionThpCache, sync_bit: int) -> bool:
return THP.sync_get_send_bit(cache) != sync_bit
def _conditionally_log_debug(name, message):
if __debug__:
log.debug(__name__, "Received ACK message with correct sync bit")
THP.sync_set_can_send_message(session, True)
def _ack_is_not_expected(session: SessionThpCache) -> bool:
return THP.sync_can_send_message(session)
def _ack_has_incorrect_sync_bit(session: SessionThpCache, sync_bit: int) -> bool:
return THP.sync_get_send_bit(session) != sync_bit
log.debug(name, message)

View File

@ -1,26 +1,53 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import ( # pyright:ignore[reportShadowedImports]
TYPE_CHECKING,
Any,
Callable,
Coroutine,
)
import usb
from storage import cache_thp
from storage.cache_thp import ChannelCache
from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache
from trezor import loop, protobuf, utils
from ..protocol_common import Context
from . import ChannelState, SessionState, checksum
from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH
# from . import thp_session
from .thp_messages import CONTINUATION_PACKET, ENCRYPTED_TRANSPORT
from .thp_messages import (
ACK_MESSAGE,
CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT,
HANDSHAKE_INIT,
)
from .thp_session import ThpError
# from .thp_session import SessionState, ThpError
if TYPE_CHECKING:
from trezorio import WireInterface
from trezorio import WireInterface # type:ignore
Handler = Callable[
[bytes, Any, Any, Any], Coroutine
] # TODO Adjust parameters to be more restrictive
_INIT_DATA_OFFSET = const(5)
_CONT_DATA_OFFSET = const(3)
_INIT_DATA_OFFSET = const(5)
_REPORT_CONT_DATA_OFFSET = const(3)
_WIRE_INTERFACE_USB = b"\x00"
_WIRE_INTERFACE_USB = b"\x01"
_MOCK_INTERFACE_HID = b"\x00"
_PUBKEY_LENGTH = const(32)
_REPORT_LENGTH = const(64)
_MAX_PAYLOAD_LEN = const(60000)
class ChannelContext(Context):
@ -33,6 +60,9 @@ class ChannelContext(Context):
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read = 0
from trezor.wire.thp.session_context import load_cached_sessions
self.sessions = load_cached_sessions(self)
@classmethod
def create_new_channel(cls, iface: WireInterface) -> "ChannelContext":
@ -41,9 +71,12 @@ class ChannelContext(Context):
# ACCESS TO CHANNEL_DATA
def get_management_session_state(self): # TODO redo for channel state
# return thp_session.get_state(self.session_data)
pass
def get_channel_state(self) -> ChannelState:
state = int.from_bytes(self.channel_cache.state, "big")
return ChannelState(state)
def set_channel_state(self, state: ChannelState) -> None:
self.channel_cache.state = bytearray(state.value.to_bytes(1, "big"))
# CALLED BY THP_MAIN_LOOP
@ -54,14 +87,148 @@ class ChannelContext(Context):
else:
await self._handle_init_packet(packet)
if self.expected_payload_length == self.bytes_read:
self._finish_message()
await self._handle_completed_message()
async def _handle_init_packet(self, packet):
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
packet_payload = packet[5:]
# If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed!
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
packet_payload = self._decode(packet_payload)
packet_payload = self._decrypt(packet_payload)
# session_id = packet_payload[0] # TODO handle handshake differently
state = self.get_channel_state()
if state is ChannelState.ENCRYPTED_TRANSPORT:
session_id = packet_payload[0]
if session_id == 0:
pass
# TODO use small buffer
else:
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return
else:
pass
# TODO use small buffer
# TODO for now, we create a new big buffer every time. It should be changed
self.buffer = _get_buffer_for_payload(payload_length, self.buffer)
await self._buffer_packet_data(self.buffer, packet, 0)
async def _handle_cont_packet(self, packet):
if not self.is_cont_packet_expected:
return # Continuation packet is not expected, ignoring
await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET)
async def _handle_completed_message(self):
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
msg_len = payload_length + _INIT_DATA_OFFSET
if not checksum.is_valid(
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
):
# checksum is not valid -> ignore message
self._todo_clear_buffer()
return
sync_bit = (ctrl_byte & 0x10) >> 4
if _is_ctrl_byte_ack(ctrl_byte):
self._handle_received_ACK(sync_bit)
self._todo_clear_buffer()
return
state = self.get_channel_state()
if state is ChannelState.TH1:
if not _is_ctrl_byte_handshake_init:
raise ThpError("Message received is not a handshake init request!")
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError(
"Message received is not a valid handshake init request!"
)
host_ephemeral_key = bytearray(
self.buffer[_INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(
self.channel_cache, host_ephemeral_key
)
# TODO send ack in response
# TODO send handshake init response message
self.set_channel_state(ChannelState.TH2)
return
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
# TODO ignore message
self._todo_clear_buffer()
return
if state is ChannelState.ENCRYPTED_TRANSPORT:
self._decrypt_buffer()
session_id, message_type = ustruct.unpack(
">BH", self.buffer[_INIT_DATA_OFFSET:]
)
if session_id not in self.sessions:
raise Exception("Unalloacted session")
session_state = self.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED:
raise Exception("Unalloacted session")
await self.sessions[session_id].receive_message(
message_type,
self.buffer[_INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
)
if state is ChannelState.TH2:
host_encrypted_static_pubkey = self.buffer[
_INIT_DATA_OFFSET : _INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = self.buffer[
_INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
]
print(
host_encrypted_static_pubkey,
handshake_completion_request_noise_payload,
) # TODO remove
# TODO send ack in response
# TODO send hanshake completion response
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
def _decrypt(self, payload) -> bytes:
return payload # TODO add decryption process
def _decrypt_buffer(self) -> None:
pass
# TODO decode buffer in place
async def _buffer_packet_data(
self, payload_buffer, packet: utils.BufferType, offset
):
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
def _finish_message(self):
self.bytes_read = 0
self.expected_payload_length = 0
self.is_cont_packet_expected = False
def _get_handler(self) -> Handler:
state = self.get_channel_state()
if state is ChannelState.UNAUTHENTICATED:
return self._handler_unauthenticated
if state is ChannelState.ENCRYPTED_TRANSPORT:
return self._handler_encrypted_transport
raise Exception("Unimplemented situation")
# Handlers for init packets
# TODO adjust
async def _handler_encrypted_transport(
self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet
) -> None:
self.expected_payload_length = payload_length
self.bytes_read = 0
@ -73,24 +240,20 @@ class ChannelContext(Context):
else:
self.is_cont_packet_expected = True
async def _handle_cont_packet(self, packet):
if not self.is_cont_packet_expected:
return # Continuation packet is not expected, ignoring
await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET)
def _decode(self, payload) -> bytes:
return payload # TODO add decryption process
async def _buffer_packet_data(
self, payload_buffer, packet: utils.BufferType, offset
):
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
def _finish_message(self):
# TODO Provide loaded message to SessionContext or handle it with this ChannelContext
# TODO adjust
async def _handler_unauthenticated(
self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet
) -> None:
self.expected_payload_length = payload_length
self.bytes_read = 0
self.expected_payload_length = 0
self.is_cont_packet_expected = False
await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET)
# TODO Set/Provide different buffer for management session
if self.expected_payload_length == self.bytes_read:
self._finish_message()
else:
self.is_cont_packet_expected = True
# CALLED BY WORKFLOW / SESSION CONTEXT
@ -105,6 +268,29 @@ class ChannelContext(Context):
pass
# create a new session with this passphrase
# OTHER
def _todo_clear_buffer(self):
raise NotImplementedError()
# TODO add debug logging to ACK handling
def _handle_received_ACK(self, sync_bit: int) -> None:
if self._ack_is_not_expected():
return
if self._ack_has_incorrect_sync_bit(sync_bit):
return
if self.waiting_for_ack_timeout is not None:
self.waiting_for_ack_timeout.close()
THP.sync_set_can_send_message(self.channel_cache, True)
def _ack_is_not_expected(self) -> bool:
return THP.sync_can_send_message(self.channel_cache)
def _ack_has_incorrect_sync_bit(self, sync_bit: int) -> bool:
return THP.sync_get_send_bit(self.channel_cache) != sync_bit
def load_cached_channels() -> dict[int, ChannelContext]: # TODO
channels: dict[int, ChannelContext] = {}
@ -120,6 +306,8 @@ def _decode_iface(cached_iface: bytes) -> WireInterface:
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
if __debug__ and cached_iface == _MOCK_INTERFACE_HID:
raise NotImplementedError("Should return MockHID WireInterface")
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
@ -128,6 +316,8 @@ def _encode_iface(iface: WireInterface) -> bytes:
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
if __debug__:
return _MOCK_INTERFACE_HID
raise Exception("Unknown WireInterface")
@ -137,3 +327,29 @@ def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_INIT
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN
) -> utils.BufferType:
if payload_length > max_length:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
# allocate a new buffer to fit the message
try:
payload: utils.BufferType = bytearray(payload_length)
except MemoryError:
payload = bytearray(_REPORT_LENGTH)
raise ThpError("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:payload_length]

View File

@ -5,7 +5,7 @@ from .channel_context import ChannelContext
def getPacketHandler(
channel: ChannelContext, packet: bytes
): # TODO is the packet bytes or BufferType?
if channel.get_management_session_state is ChannelState.TH1: # TODO is correct
if channel.get_channel_state is ChannelState.TH1: # TODO is correct
# return handler_TH_1
pass

View File

@ -1,14 +1,55 @@
from storage import cache_thp
from storage.cache_thp import SessionThpCache
from trezor import protobuf
from ..context import Context
from ..protocol_common import Context
from . import SessionState
from .channel_context import ChannelContext
class SessionContext(Context):
def __init__(self, channel_context: ChannelContext, session_id: int) -> None:
def __init__(
self, channel_context: ChannelContext, session_cache: SessionThpCache
) -> None:
if channel_context.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
)
super().__init__(channel_context.iface, channel_context.channel_id)
self.channel_context = channel_context
self.session_id = session_id
self.session_cache = session_cache
self.session_id = session_cache.session_id
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_context.write(msg, self.session_id)
return await self.channel_context.write(
msg, int.from_bytes(self.session_id, "big")
)
@classmethod
def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext":
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
return cls(channel_context, session_cache)
# ACCESS TO SESSION DATA
def get_session_state(self) -> SessionState:
state = int.from_bytes(self.session_cache.state, "big")
return SessionState(state)
def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.value.to_bytes(1, "big"))
# Called by channel context
async def receive_message(self, message_type, encoded_protobuf_message):
pass # TODO implement
def load_cached_sessions(channel: ChannelContext) -> dict[int, SessionContext]: # TODO
sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
for session in cached_sessions:
if session.channel_id == channel.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel, session)
return sessions

View File

@ -1,11 +1,14 @@
import ustruct
import ustruct # pyright:ignore[reportMissingModuleSource]
from storage.cache_thp import BROADCAST_CHANNEL_ID
from ..protocol_common import Message
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
ENCRYPTED_TRANSPORT = 0x02
HANDSHAKE_INIT = 0x00
ACK_MESSAGE = 0x20
_ERROR = 0x41
_CHANNEL_ALLOCATION_RES = 0x40
@ -63,9 +66,9 @@ def get_device_properties() -> Message:
return Message(_ENCODED_PROTOBUF_DEVICE_PROPERTIES)
def get_channel_allocation_response(nonce: bytes, new_cid: int) -> bytes:
def get_channel_allocation_response(nonce: bytes, new_cid: bytes) -> bytes:
props_msg = get_device_properties()
return ustruct.pack(">8sH", nonce, new_cid) + props_msg.to_bytes()
return nonce + new_cid + props_msg.to_bytes()
def get_error_unallocated_channel() -> bytes:

View File

@ -2,7 +2,7 @@ import ustruct
from typing import TYPE_CHECKING
from storage import cache_thp as storage_thp_cache
from storage.cache_thp import SessionThpCache
from storage.cache_thp import ChannelCache, SessionThpCache
from trezor.wire.protocol_common import WireError
if TYPE_CHECKING:
@ -62,40 +62,40 @@ def get_cid(session: SessionThpCache) -> int:
return storage_thp_cache._get_cid(session)
def get_next_channel_id() -> int: # deprecated TODO remove
return int.from_bytes(storage_thp_cache.get_next_channel_id(), "big")
def sync_can_send_message(cache: SessionThpCache | ChannelCache) -> bool:
return cache.sync & 0x80 == 0x80
def sync_can_send_message(session: SessionThpCache) -> bool:
return session.sync & 0x80 == 0x80
def sync_get_receive_expected_bit(cache: SessionThpCache | ChannelCache) -> int:
return (cache.sync & 0x40) >> 6
def sync_get_receive_expected_bit(session: SessionThpCache) -> int:
return (session.sync & 0x40) >> 6
def sync_get_send_bit(cache: SessionThpCache | ChannelCache) -> int:
return (cache.sync & 0x20) >> 5
def sync_get_send_bit(session: SessionThpCache) -> int:
return (session.sync & 0x20) >> 5
def sync_set_can_send_message(session: SessionThpCache, can_send: bool) -> None:
session.sync &= 0x7F
def sync_set_can_send_message(
cache: SessionThpCache | ChannelCache, can_send: bool
) -> None:
cache.sync &= 0x7F
if can_send:
session.sync |= 0x80
cache.sync |= 0x80
def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None:
def sync_set_receive_expected_bit(
cache: SessionThpCache | ChannelCache, bit: int
) -> None:
if bit not in (0, 1):
raise ThpError("Unexpected receive sync bit")
# set second bit to "bit" value
session.sync &= 0xBF
cache.sync &= 0xBF
if bit:
session.sync |= 0x40
cache.sync |= 0x40
def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None:
_sync_set_send_bit(session=session, bit=1 - sync_get_send_bit(session))
def sync_set_send_bit_to_opposite(cache: SessionThpCache | ChannelCache) -> None:
_sync_set_send_bit(cache=cache, bit=1 - sync_get_send_bit(cache))
def is_active_session(session: SessionThpCache):
@ -126,13 +126,13 @@ def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache | None:
return None
def _sync_set_send_bit(session: SessionThpCache, bit: int) -> None:
def _sync_set_send_bit(cache: SessionThpCache | ChannelCache, bit: int) -> None:
if bit not in (0, 1):
raise ThpError("Unexpected send sync bit")
# set third bit to "bit" value
session.sync &= 0xDF
session.sync |= 0x20
cache.sync &= 0xDF
cache.sync |= 0x20
def _decode_session_state(state: bytearray) -> int:

View File

@ -1,16 +1,24 @@
import ustruct
from micropython import const
from typing import TYPE_CHECKING
import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache
from trezor import io, log, loop, utils
from .protocol_common import MessageWithId
from .thp import ack_handler, checksum, thp_messages
from .thp import ChannelState, ack_handler, checksum, thp_messages
from .thp import thp_session as THP
from .thp.channel_context import ChannelContext, load_cached_channels
from .thp.channel_context import (
_INIT_DATA_OFFSET,
_MAX_PAYLOAD_LEN,
_REPORT_CONT_DATA_OFFSET,
_REPORT_LENGTH,
ChannelContext,
load_cached_channels,
)
from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import (
CODEC_V1,
CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT,
InitHeader,
@ -19,18 +27,13 @@ from .thp.thp_messages import (
from .thp.thp_session import SessionState, ThpError
if TYPE_CHECKING:
from trezorio import WireInterface
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
_MAX_PAYLOAD_LEN = const(60000)
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
_CHANNEL_ALLOCATION_REQ = 0x40
_ACK_MESSAGE = 0x20
_HANDSHAKE_INIT = 0x00
_PLAINTEXT = 0x01
_REPORT_LENGTH = const(64)
_REPORT_INIT_DATA_OFFSET = const(5)
_REPORT_CONT_DATA_OFFSET = const(3)
_BUFFER: bytearray
_BUFFER_LOCK = None
@ -63,6 +66,12 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
packet = await read
ctrl_byte, cid = ustruct.unpack(">BH", packet)
if ctrl_byte == CODEC_V1:
pass
# TODO add handling of (unsupported) codec_v1 packets
# possibly ignore continuation packets, i.e. if the
# following bytes are not "##"", do not respond
if cid == BROADCAST_CHANNEL_ID:
# TODO handle exceptions, try-catch?
await _handle_broadcast(iface, ctrl_byte, packet)
@ -72,10 +81,10 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
channel = _CHANNEL_CONTEXTS[cid]
if channel is None:
raise ThpError("Invalid state of a channel")
# TODO if the channelContext interface is not None and is different from
# the one used in the transmission of the packet, raise an exception
# TODO add current wire interface to channelContext if its iface is None
if channel.get_management_session_state != SessionState.UNALLOCATED:
if channel.iface is not iface:
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
await channel.receive_packet(packet)
continue
@ -204,7 +213,7 @@ async def _buffer_received_data(
payload: utils.BufferType, header: InitHeader, iface, report
) -> None | InterruptingInitPacket:
# buffer the initial data
nread = utils.memcpy(payload, 0, report, _REPORT_INIT_DATA_OFFSET)
nread = utils.memcpy(payload, 0, report, _INIT_DATA_OFFSET)
while nread < header.length:
# wait for continuation report
report = await _get_loop_wait_read(iface)
@ -297,7 +306,7 @@ async def write_to_wire(
header.pack_to_buffer(report)
# write initial report
nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0)
nwritten = utils.memcpy(report, _INIT_DATA_OFFSET, payload, 0)
await _write_report(loop_write, iface, report)
# if we have more data to write, use continuation reports for it
@ -332,13 +341,13 @@ async def _handle_broadcast(
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid")
deprecated_channel_id = _get_new_channel_id() # TODO remove
THP.create_new_unauthenticated_session(iface, deprecated_channel_id) # TODO remove
new_context: ChannelContext = ChannelContext.create_new_channel(iface)
cid = int.from_bytes(new_context.channel_id, "big")
_CHANNEL_CONTEXTS[cid] = new_context
response_data = thp_messages.get_channel_allocation_response(nonce, cid)
response_data = thp_messages.get_channel_allocation_response(
nonce, new_context.channel_id
)
response_header = InitHeader.get_channel_allocation_response_header(
len(response_data) + CHECKSUM_LENGTH,
)
@ -389,10 +398,6 @@ async def _handle_unexpected_sync_bit(
# (some such messages might be handled in the classical "allocated" way, if the sync bit is right)
def _get_new_channel_id() -> int:
return THP.get_next_channel_id()
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
return ctrl_byte & 0x80 == CONTINUATION_PACKET