mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
feat(core): improve channels, sessions and handshake
This commit is contained in:
parent
42873b1c30
commit
7c447ac5d1
@ -19,10 +19,11 @@ _MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove
|
||||
_CHANNEL_STATE_LENGTH = const(1)
|
||||
_WIRE_INTERFACE_LENGTH = const(1)
|
||||
_SESSION_STATE_LENGTH = const(1)
|
||||
_CHANNEL_ID_LENGTH = const(4)
|
||||
_SESSION_ID_LENGTH = const(4)
|
||||
_CHANNEL_ID_LENGTH = const(2)
|
||||
_SESSION_ID_LENGTH = const(1)
|
||||
BROADCAST_CHANNEL_ID = const(65535)
|
||||
|
||||
KEY_LENGTH = const(32)
|
||||
TAG_LENGTH = const(16)
|
||||
_UNALLOCATED_STATE = const(0)
|
||||
|
||||
|
||||
@ -40,10 +41,14 @@ class ConnectionCache(DataCache):
|
||||
|
||||
class ChannelCache(ConnectionCache):
|
||||
def __init__(self) -> None:
|
||||
self.enc_key = 0 # TODO change
|
||||
self.dec_key = 1 # TODO change
|
||||
self.host_ephemeral_pubkey = bytearray(KEY_LENGTH)
|
||||
self.enc_key = bytearray(KEY_LENGTH)
|
||||
self.dec_key = bytearray(KEY_LENGTH)
|
||||
self.state = bytearray(_CHANNEL_STATE_LENGTH)
|
||||
self.iface = bytearray(1) # TODO add decoding
|
||||
self.sync = 0x80 # can_send_bit | sync_receive_bit | sync_send_bit | rfu(5)
|
||||
self.session_id_counter = 0x01
|
||||
self.fields = ()
|
||||
super().__init__()
|
||||
|
||||
def clear(self) -> None:
|
||||
@ -135,13 +140,20 @@ def get_new_unauthenticated_channel(iface: bytes) -> ChannelCache:
|
||||
new_cid = get_next_channel_id()
|
||||
index = _get_next_unauthenticated_channel_index()
|
||||
|
||||
# clear sessions from replaced channel
|
||||
if _get_channel_state(_CHANNELS[index]) != _UNALLOCATED_STATE:
|
||||
old_cid = _CHANNELS[index].channel_id
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id == old_cid:
|
||||
session.clear()
|
||||
|
||||
_CHANNELS[index] = ChannelCache()
|
||||
_CHANNELS[index].channel_id[:] = new_cid
|
||||
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
|
||||
_CHANNELS[index].state = bytearray(
|
||||
_CHANNELS[index].state[:] = bytearray(
|
||||
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
|
||||
)
|
||||
_CHANNELS[index].iface = bytearray(iface)
|
||||
_CHANNELS[index].iface[:] = bytearray(iface)
|
||||
return _CHANNELS[index]
|
||||
|
||||
|
||||
@ -153,6 +165,38 @@ def get_all_allocated_channels() -> list[ChannelCache]:
|
||||
return _list
|
||||
|
||||
|
||||
def get_all_allocated_sessions() -> list[SessionThpCache]:
|
||||
_list: list[SessionThpCache] = []
|
||||
for session in _SESSIONS:
|
||||
if _get_session_state(session) != _UNALLOCATED_STATE:
|
||||
_list.append(session)
|
||||
return _list
|
||||
|
||||
|
||||
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
|
||||
if len(key) != KEY_LENGTH:
|
||||
raise Exception("Invalid key length")
|
||||
channel.host_ephemeral_pubkey = key
|
||||
|
||||
|
||||
def get_new_session(channel: ChannelCache):
|
||||
|
||||
new_sid = get_next_session_id(channel)
|
||||
index = _get_next_session_index()
|
||||
|
||||
_SESSIONS[index] = SessionThpCache()
|
||||
_SESSIONS[index].channel_id[:] = channel.channel_id
|
||||
_SESSIONS[index].session_id[:] = new_sid
|
||||
_SESSIONS[index].last_usage = _get_usage_counter_and_increment()
|
||||
channel.last_usage = (
|
||||
_get_usage_counter_and_increment()
|
||||
) # increment also use of the channel so it does not get replaced
|
||||
_SESSIONS[index].state[:] = bytearray(
|
||||
_UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big")
|
||||
)
|
||||
return _SESSIONS[index]
|
||||
|
||||
|
||||
def _get_usage_counter() -> int:
|
||||
global _usage_counter
|
||||
return _usage_counter
|
||||
@ -171,6 +215,13 @@ def _get_next_unauthenticated_channel_index() -> int:
|
||||
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
|
||||
|
||||
|
||||
def _get_next_session_index() -> int:
|
||||
idx = _get_unallocated_session_index()
|
||||
if idx is not None:
|
||||
return idx
|
||||
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
|
||||
|
||||
|
||||
def _get_unallocated_channel_index() -> int | None:
|
||||
for i in range(_MAX_CHANNELS_COUNT):
|
||||
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
|
||||
@ -178,12 +229,25 @@ def _get_unallocated_channel_index() -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def _get_unallocated_session_index() -> int | None:
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if (_SESSIONS[i]) is _UNALLOCATED_STATE:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def _get_channel_state(channel: ChannelCache) -> int:
|
||||
if channel is None:
|
||||
return _UNALLOCATED_STATE
|
||||
return int.from_bytes(channel.state, "big")
|
||||
|
||||
|
||||
def _get_session_state(session: SessionThpCache) -> int:
|
||||
if session is None:
|
||||
return _UNALLOCATED_STATE
|
||||
return int.from_bytes(session.state, "big")
|
||||
|
||||
|
||||
def get_active_session_id() -> bytearray | None:
|
||||
active_session = get_active_session()
|
||||
|
||||
@ -200,9 +264,6 @@ def get_active_session() -> SessionThpCache | None:
|
||||
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
def get_next_channel_id() -> bytes:
|
||||
global cid_counter
|
||||
while True:
|
||||
@ -214,6 +275,25 @@ def get_next_channel_id() -> bytes:
|
||||
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
|
||||
|
||||
|
||||
def get_next_session_id(channel: ChannelCache) -> bytes:
|
||||
while not _is_session_id_unique(channel):
|
||||
if channel.session_id_counter >= 255:
|
||||
channel.session_id_counter = 1
|
||||
else:
|
||||
channel.session_id_counter += 1
|
||||
new_sid = channel.session_id_counter
|
||||
channel.session_id_counter += 1
|
||||
return new_sid.to_bytes(_SESSION_ID_LENGTH, "big")
|
||||
|
||||
|
||||
def _is_session_id_unique(channel: ChannelCache) -> bool:
|
||||
for session in _SESSIONS:
|
||||
if session.channel_id == channel.channel_id:
|
||||
if session.session_id == channel.session_id_counter:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _is_cid_unique() -> bool:
|
||||
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
||||
if cid_counter == _get_cid(session):
|
||||
@ -226,8 +306,10 @@ def _get_cid(session: SessionThpCache) -> int:
|
||||
|
||||
|
||||
def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache:
|
||||
if len(session_id) != 4:
|
||||
raise ValueError("session_id must be 4 bytes long.")
|
||||
if len(session_id) != _SESSION_ID_LENGTH:
|
||||
raise ValueError(
|
||||
"session_id must be X bytes long, where X=", _SESSION_ID_LENGTH
|
||||
)
|
||||
global _active_session_idx
|
||||
global _is_active_session_authenticated
|
||||
global _next_unauthenicated_session_index
|
||||
@ -302,7 +384,10 @@ def start_session(session_id: bytes | None) -> bytes: # TODO incomplete
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
return session_id
|
||||
new_session_id = b"\x00\x00" + get_next_channel_id()
|
||||
|
||||
channel = get_new_unauthenticated_channel(b"\x00")
|
||||
|
||||
new_session_id = get_next_session_id(channel)
|
||||
|
||||
new_session = create_new_unauthenticated_session(new_session_id)
|
||||
|
||||
|
@ -40,7 +40,13 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Any, Callable, Container, Coroutine, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Container,
|
||||
Coroutine,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
@ -53,10 +59,11 @@ EXPERIMENTAL_ENABLED = False
|
||||
|
||||
|
||||
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
||||
"""Initialize the wire stack on passed USB interface."""
|
||||
loop.schedule(
|
||||
handle_session(iface, codec_v1.SESSION_ID.to_bytes(4, "big"), is_debug_session)
|
||||
)
|
||||
"""Initialize the wire stack on passed WireInterface."""
|
||||
if utils.USE_THP:
|
||||
loop.schedule(handle_thp_session(iface, is_debug_session))
|
||||
else:
|
||||
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
|
||||
|
||||
|
||||
def wrap_protobuf_load(
|
||||
@ -128,13 +135,13 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
|
||||
|
||||
|
||||
async def handle_session(
|
||||
iface: WireInterface, session_id: bytes, is_debug_session: bool = False
|
||||
iface: WireInterface, codec_session_id: int, is_debug_session: bool = False
|
||||
) -> None:
|
||||
if __debug__ and is_debug_session:
|
||||
ctx_buffer = WIRE_BUFFER_DEBUG
|
||||
else:
|
||||
ctx_buffer = WIRE_BUFFER
|
||||
|
||||
session_id = codec_session_id.to_bytes(4, "big")
|
||||
ctx = context.CodecContext(iface, ctx_buffer, session_id)
|
||||
next_msg: protocol_common.MessageWithId | None = None
|
||||
|
||||
|
@ -15,7 +15,13 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Any, Callable, Container, Coroutine, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Container,
|
||||
Coroutine,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
|
@ -41,13 +41,13 @@ class MessageWithId(MessageWithType):
|
||||
super().__init__(message_type, message_data)
|
||||
|
||||
|
||||
class WireError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, iface: WireInterface, channel_id: bytes) -> None:
|
||||
self.iface: WireInterface = iface
|
||||
self.channel_id: bytes = channel_id
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None: ...
|
||||
|
||||
|
||||
class WireError(Exception):
|
||||
pass
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import IntEnum
|
||||
@ -19,6 +19,10 @@ class ChannelState(IntEnum):
|
||||
ENCRYPTED_TRANSPORT = 9
|
||||
|
||||
|
||||
class SessionState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
|
||||
|
||||
class WireInterfaceType(IntEnum):
|
||||
MOCK = 0
|
||||
USB = 1
|
||||
|
@ -4,26 +4,28 @@ from trezor import log
|
||||
from . import thp_session as THP
|
||||
|
||||
|
||||
def handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
|
||||
def handle_received_ACK(cache: SessionThpCache, sync_bit: int) -> None:
|
||||
|
||||
if _ack_is_not_expected(session):
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received unexpected ACK message")
|
||||
if _ack_is_not_expected(cache):
|
||||
_conditionally_log_debug(__name__, "Received unexpected ACK message")
|
||||
return
|
||||
if _ack_has_incorrect_sync_bit(session, sync_bit):
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received ACK message with wrong sync bit")
|
||||
if _ack_has_incorrect_sync_bit(cache, sync_bit):
|
||||
_conditionally_log_debug(__name__, "Received ACK message with wrong sync bit")
|
||||
return
|
||||
|
||||
# ACK is expected and it has correct sync bit
|
||||
_conditionally_log_debug(__name__, "Received ACK message with correct sync bit")
|
||||
THP.sync_set_can_send_message(cache, True)
|
||||
|
||||
|
||||
def _ack_is_not_expected(cache: SessionThpCache) -> bool:
|
||||
return THP.sync_can_send_message(cache)
|
||||
|
||||
|
||||
def _ack_has_incorrect_sync_bit(cache: SessionThpCache, sync_bit: int) -> bool:
|
||||
return THP.sync_get_send_bit(cache) != sync_bit
|
||||
|
||||
|
||||
def _conditionally_log_debug(name, message):
|
||||
if __debug__:
|
||||
log.debug(__name__, "Received ACK message with correct sync bit")
|
||||
THP.sync_set_can_send_message(session, True)
|
||||
|
||||
|
||||
def _ack_is_not_expected(session: SessionThpCache) -> bool:
|
||||
return THP.sync_can_send_message(session)
|
||||
|
||||
|
||||
def _ack_has_incorrect_sync_bit(session: SessionThpCache, sync_bit: int) -> bool:
|
||||
return THP.sync_get_send_bit(session) != sync_bit
|
||||
log.debug(name, message)
|
||||
|
@ -1,26 +1,53 @@
|
||||
import ustruct
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
import ustruct # pyright: ignore[reportMissingModuleSource]
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import ( # pyright:ignore[reportShadowedImports]
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
)
|
||||
|
||||
import usb
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import ChannelCache
|
||||
from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache
|
||||
from trezor import loop, protobuf, utils
|
||||
|
||||
from ..protocol_common import Context
|
||||
from . import ChannelState, SessionState, checksum
|
||||
from . import thp_session as THP
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
|
||||
# from . import thp_session
|
||||
from .thp_messages import CONTINUATION_PACKET, ENCRYPTED_TRANSPORT
|
||||
from .thp_messages import (
|
||||
ACK_MESSAGE,
|
||||
CONTINUATION_PACKET,
|
||||
ENCRYPTED_TRANSPORT,
|
||||
HANDSHAKE_INIT,
|
||||
)
|
||||
from .thp_session import ThpError
|
||||
|
||||
# from .thp_session import SessionState, ThpError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from trezorio import WireInterface # type:ignore
|
||||
|
||||
Handler = Callable[
|
||||
[bytes, Any, Any, Any], Coroutine
|
||||
] # TODO Adjust parameters to be more restrictive
|
||||
|
||||
|
||||
_INIT_DATA_OFFSET = const(5)
|
||||
_CONT_DATA_OFFSET = const(3)
|
||||
_INIT_DATA_OFFSET = const(5)
|
||||
_REPORT_CONT_DATA_OFFSET = const(3)
|
||||
|
||||
_WIRE_INTERFACE_USB = b"\x00"
|
||||
_WIRE_INTERFACE_USB = b"\x01"
|
||||
_MOCK_INTERFACE_HID = b"\x00"
|
||||
|
||||
_PUBKEY_LENGTH = const(32)
|
||||
|
||||
_REPORT_LENGTH = const(64)
|
||||
_MAX_PAYLOAD_LEN = const(60000)
|
||||
|
||||
|
||||
class ChannelContext(Context):
|
||||
@ -33,6 +60,9 @@ class ChannelContext(Context):
|
||||
self.is_cont_packet_expected: bool = False
|
||||
self.expected_payload_length: int = 0
|
||||
self.bytes_read = 0
|
||||
from trezor.wire.thp.session_context import load_cached_sessions
|
||||
|
||||
self.sessions = load_cached_sessions(self)
|
||||
|
||||
@classmethod
|
||||
def create_new_channel(cls, iface: WireInterface) -> "ChannelContext":
|
||||
@ -41,9 +71,12 @@ class ChannelContext(Context):
|
||||
|
||||
# ACCESS TO CHANNEL_DATA
|
||||
|
||||
def get_management_session_state(self): # TODO redo for channel state
|
||||
# return thp_session.get_state(self.session_data)
|
||||
pass
|
||||
def get_channel_state(self) -> ChannelState:
|
||||
state = int.from_bytes(self.channel_cache.state, "big")
|
||||
return ChannelState(state)
|
||||
|
||||
def set_channel_state(self, state: ChannelState) -> None:
|
||||
self.channel_cache.state = bytearray(state.value.to_bytes(1, "big"))
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
|
||||
@ -54,14 +87,148 @@ class ChannelContext(Context):
|
||||
else:
|
||||
await self._handle_init_packet(packet)
|
||||
|
||||
if self.expected_payload_length == self.bytes_read:
|
||||
self._finish_message()
|
||||
await self._handle_completed_message()
|
||||
|
||||
async def _handle_init_packet(self, packet):
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
|
||||
packet_payload = packet[5:]
|
||||
|
||||
# If the channel does not "own" the buffer lock, decrypt first packet
|
||||
# TODO do it only when needed!
|
||||
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
||||
packet_payload = self._decode(packet_payload)
|
||||
packet_payload = self._decrypt(packet_payload)
|
||||
|
||||
# session_id = packet_payload[0] # TODO handle handshake differently
|
||||
state = self.get_channel_state()
|
||||
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
session_id = packet_payload[0]
|
||||
if session_id == 0:
|
||||
pass
|
||||
# TODO use small buffer
|
||||
else:
|
||||
pass
|
||||
# TODO use big buffer but only if the channel owns the buffer lock.
|
||||
# Otherwise send BUSY message and return
|
||||
else:
|
||||
pass
|
||||
# TODO use small buffer
|
||||
|
||||
# TODO for now, we create a new big buffer every time. It should be changed
|
||||
self.buffer = _get_buffer_for_payload(payload_length, self.buffer)
|
||||
|
||||
await self._buffer_packet_data(self.buffer, packet, 0)
|
||||
|
||||
async def _handle_cont_packet(self, packet):
|
||||
if not self.is_cont_packet_expected:
|
||||
return # Continuation packet is not expected, ignoring
|
||||
await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET)
|
||||
|
||||
async def _handle_completed_message(self):
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
|
||||
msg_len = payload_length + _INIT_DATA_OFFSET
|
||||
if not checksum.is_valid(
|
||||
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
|
||||
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
|
||||
):
|
||||
# checksum is not valid -> ignore message
|
||||
self._todo_clear_buffer()
|
||||
return
|
||||
|
||||
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||
if _is_ctrl_byte_ack(ctrl_byte):
|
||||
self._handle_received_ACK(sync_bit)
|
||||
self._todo_clear_buffer()
|
||||
return
|
||||
|
||||
state = self.get_channel_state()
|
||||
|
||||
if state is ChannelState.TH1:
|
||||
if not _is_ctrl_byte_handshake_init:
|
||||
raise ThpError("Message received is not a handshake init request!")
|
||||
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
||||
raise ThpError(
|
||||
"Message received is not a valid handshake init request!"
|
||||
)
|
||||
host_ephemeral_key = bytearray(
|
||||
self.buffer[_INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH]
|
||||
)
|
||||
cache_thp.set_channel_host_ephemeral_key(
|
||||
self.channel_cache, host_ephemeral_key
|
||||
)
|
||||
# TODO send ack in response
|
||||
# TODO send handshake init response message
|
||||
self.set_channel_state(ChannelState.TH2)
|
||||
return
|
||||
|
||||
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
||||
# TODO ignore message
|
||||
self._todo_clear_buffer()
|
||||
return
|
||||
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
self._decrypt_buffer()
|
||||
session_id, message_type = ustruct.unpack(
|
||||
">BH", self.buffer[_INIT_DATA_OFFSET:]
|
||||
)
|
||||
if session_id not in self.sessions:
|
||||
raise Exception("Unalloacted session")
|
||||
|
||||
session_state = self.sessions[session_id].get_session_state()
|
||||
if session_state is SessionState.UNALLOCATED:
|
||||
raise Exception("Unalloacted session")
|
||||
|
||||
await self.sessions[session_id].receive_message(
|
||||
message_type,
|
||||
self.buffer[_INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
|
||||
)
|
||||
|
||||
if state is ChannelState.TH2:
|
||||
host_encrypted_static_pubkey = self.buffer[
|
||||
_INIT_DATA_OFFSET : _INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
||||
]
|
||||
handshake_completion_request_noise_payload = self.buffer[
|
||||
_INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
|
||||
]
|
||||
print(
|
||||
host_encrypted_static_pubkey,
|
||||
handshake_completion_request_noise_payload,
|
||||
) # TODO remove
|
||||
# TODO send ack in response
|
||||
# TODO send hanshake completion response
|
||||
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
|
||||
def _decrypt(self, payload) -> bytes:
|
||||
return payload # TODO add decryption process
|
||||
|
||||
def _decrypt_buffer(self) -> None:
|
||||
pass
|
||||
# TODO decode buffer in place
|
||||
|
||||
async def _buffer_packet_data(
|
||||
self, payload_buffer, packet: utils.BufferType, offset
|
||||
):
|
||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||
|
||||
def _finish_message(self):
|
||||
self.bytes_read = 0
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
def _get_handler(self) -> Handler:
|
||||
state = self.get_channel_state()
|
||||
if state is ChannelState.UNAUTHENTICATED:
|
||||
return self._handler_unauthenticated
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
return self._handler_encrypted_transport
|
||||
raise Exception("Unimplemented situation")
|
||||
|
||||
# Handlers for init packets
|
||||
# TODO adjust
|
||||
async def _handler_encrypted_transport(
|
||||
self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet
|
||||
) -> None:
|
||||
self.expected_payload_length = payload_length
|
||||
self.bytes_read = 0
|
||||
|
||||
@ -73,24 +240,20 @@ class ChannelContext(Context):
|
||||
else:
|
||||
self.is_cont_packet_expected = True
|
||||
|
||||
async def _handle_cont_packet(self, packet):
|
||||
if not self.is_cont_packet_expected:
|
||||
return # Continuation packet is not expected, ignoring
|
||||
await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET)
|
||||
|
||||
def _decode(self, payload) -> bytes:
|
||||
return payload # TODO add decryption process
|
||||
|
||||
async def _buffer_packet_data(
|
||||
self, payload_buffer, packet: utils.BufferType, offset
|
||||
):
|
||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||
|
||||
def _finish_message(self):
|
||||
# TODO Provide loaded message to SessionContext or handle it with this ChannelContext
|
||||
# TODO adjust
|
||||
async def _handler_unauthenticated(
|
||||
self, ctrl_byte: bytes, payload_length: int, packet_payload: bytes, packet
|
||||
) -> None:
|
||||
self.expected_payload_length = payload_length
|
||||
self.bytes_read = 0
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET)
|
||||
# TODO Set/Provide different buffer for management session
|
||||
|
||||
if self.expected_payload_length == self.bytes_read:
|
||||
self._finish_message()
|
||||
else:
|
||||
self.is_cont_packet_expected = True
|
||||
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
@ -105,6 +268,29 @@ class ChannelContext(Context):
|
||||
pass
|
||||
# create a new session with this passphrase
|
||||
|
||||
# OTHER
|
||||
|
||||
def _todo_clear_buffer(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
# TODO add debug logging to ACK handling
|
||||
def _handle_received_ACK(self, sync_bit: int) -> None:
|
||||
if self._ack_is_not_expected():
|
||||
return
|
||||
if self._ack_has_incorrect_sync_bit(sync_bit):
|
||||
return
|
||||
|
||||
if self.waiting_for_ack_timeout is not None:
|
||||
self.waiting_for_ack_timeout.close()
|
||||
|
||||
THP.sync_set_can_send_message(self.channel_cache, True)
|
||||
|
||||
def _ack_is_not_expected(self) -> bool:
|
||||
return THP.sync_can_send_message(self.channel_cache)
|
||||
|
||||
def _ack_has_incorrect_sync_bit(self, sync_bit: int) -> bool:
|
||||
return THP.sync_get_send_bit(self.channel_cache) != sync_bit
|
||||
|
||||
|
||||
def load_cached_channels() -> dict[int, ChannelContext]: # TODO
|
||||
channels: dict[int, ChannelContext] = {}
|
||||
@ -120,6 +306,8 @@ def _decode_iface(cached_iface: bytes) -> WireInterface:
|
||||
if iface is None:
|
||||
raise RuntimeError("There is no valid USB WireInterface")
|
||||
return iface
|
||||
if __debug__ and cached_iface == _MOCK_INTERFACE_HID:
|
||||
raise NotImplementedError("Should return MockHID WireInterface")
|
||||
# TODO implement bluetooth interface
|
||||
raise Exception("Unknown WireInterface")
|
||||
|
||||
@ -128,6 +316,8 @@ def _encode_iface(iface: WireInterface) -> bytes:
|
||||
if iface is usb.iface_wire:
|
||||
return _WIRE_INTERFACE_USB
|
||||
# TODO implement bluetooth interface
|
||||
if __debug__:
|
||||
return _MOCK_INTERFACE_HID
|
||||
raise Exception("Unknown WireInterface")
|
||||
|
||||
|
||||
@ -137,3 +327,29 @@ def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
|
||||
|
||||
def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
|
||||
|
||||
|
||||
def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == HANDSHAKE_INIT
|
||||
|
||||
|
||||
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == ACK_MESSAGE
|
||||
|
||||
|
||||
def _get_buffer_for_payload(
|
||||
payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN
|
||||
) -> utils.BufferType:
|
||||
if payload_length > max_length:
|
||||
raise ThpError("Message too large")
|
||||
if payload_length > len(existing_buffer):
|
||||
# allocate a new buffer to fit the message
|
||||
try:
|
||||
payload: utils.BufferType = bytearray(payload_length)
|
||||
except MemoryError:
|
||||
payload = bytearray(_REPORT_LENGTH)
|
||||
raise ThpError("Message too large")
|
||||
return payload
|
||||
|
||||
# reuse a part of the supplied buffer
|
||||
return memoryview(existing_buffer)[:payload_length]
|
||||
|
@ -5,7 +5,7 @@ from .channel_context import ChannelContext
|
||||
def getPacketHandler(
|
||||
channel: ChannelContext, packet: bytes
|
||||
): # TODO is the packet bytes or BufferType?
|
||||
if channel.get_management_session_state is ChannelState.TH1: # TODO is correct
|
||||
if channel.get_channel_state is ChannelState.TH1: # TODO is correct
|
||||
# return handler_TH_1
|
||||
pass
|
||||
|
||||
|
@ -1,14 +1,55 @@
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import SessionThpCache
|
||||
from trezor import protobuf
|
||||
|
||||
from ..context import Context
|
||||
from ..protocol_common import Context
|
||||
from . import SessionState
|
||||
from .channel_context import ChannelContext
|
||||
|
||||
|
||||
class SessionContext(Context):
|
||||
def __init__(self, channel_context: ChannelContext, session_id: int) -> None:
|
||||
def __init__(
|
||||
self, channel_context: ChannelContext, session_cache: SessionThpCache
|
||||
) -> None:
|
||||
if channel_context.channel_id != session_cache.channel_id:
|
||||
raise Exception(
|
||||
"The session has different channel id than the provided channel context!"
|
||||
)
|
||||
super().__init__(channel_context.iface, channel_context.channel_id)
|
||||
self.channel_context = channel_context
|
||||
self.session_id = session_id
|
||||
self.session_cache = session_cache
|
||||
self.session_id = session_cache.session_id
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel_context.write(msg, self.session_id)
|
||||
return await self.channel_context.write(
|
||||
msg, int.from_bytes(self.session_id, "big")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext":
|
||||
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
|
||||
return cls(channel_context, session_cache)
|
||||
|
||||
# ACCESS TO SESSION DATA
|
||||
|
||||
def get_session_state(self) -> SessionState:
|
||||
state = int.from_bytes(self.session_cache.state, "big")
|
||||
return SessionState(state)
|
||||
|
||||
def set_session_state(self, state: SessionState) -> None:
|
||||
self.session_cache.state = bytearray(state.value.to_bytes(1, "big"))
|
||||
|
||||
# Called by channel context
|
||||
|
||||
async def receive_message(self, message_type, encoded_protobuf_message):
|
||||
pass # TODO implement
|
||||
|
||||
|
||||
def load_cached_sessions(channel: ChannelContext) -> dict[int, SessionContext]: # TODO
|
||||
sessions: dict[int, SessionContext] = {}
|
||||
cached_sessions = cache_thp.get_all_allocated_sessions()
|
||||
for session in cached_sessions:
|
||||
if session.channel_id == channel.channel_id:
|
||||
sid = int.from_bytes(session.session_id, "big")
|
||||
sessions[sid] = SessionContext(channel, session)
|
||||
return sessions
|
||||
|
@ -1,11 +1,14 @@
|
||||
import ustruct
|
||||
import ustruct # pyright:ignore[reportMissingModuleSource]
|
||||
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||
|
||||
from ..protocol_common import Message
|
||||
|
||||
CODEC_V1 = 0x3F
|
||||
CONTINUATION_PACKET = 0x80
|
||||
ENCRYPTED_TRANSPORT = 0x02
|
||||
HANDSHAKE_INIT = 0x00
|
||||
ACK_MESSAGE = 0x20
|
||||
_ERROR = 0x41
|
||||
_CHANNEL_ALLOCATION_RES = 0x40
|
||||
|
||||
@ -63,9 +66,9 @@ def get_device_properties() -> Message:
|
||||
return Message(_ENCODED_PROTOBUF_DEVICE_PROPERTIES)
|
||||
|
||||
|
||||
def get_channel_allocation_response(nonce: bytes, new_cid: int) -> bytes:
|
||||
def get_channel_allocation_response(nonce: bytes, new_cid: bytes) -> bytes:
|
||||
props_msg = get_device_properties()
|
||||
return ustruct.pack(">8sH", nonce, new_cid) + props_msg.to_bytes()
|
||||
return nonce + new_cid + props_msg.to_bytes()
|
||||
|
||||
|
||||
def get_error_unallocated_channel() -> bytes:
|
||||
|
@ -2,7 +2,7 @@ import ustruct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage import cache_thp as storage_thp_cache
|
||||
from storage.cache_thp import SessionThpCache
|
||||
from storage.cache_thp import ChannelCache, SessionThpCache
|
||||
from trezor.wire.protocol_common import WireError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -62,40 +62,40 @@ def get_cid(session: SessionThpCache) -> int:
|
||||
return storage_thp_cache._get_cid(session)
|
||||
|
||||
|
||||
def get_next_channel_id() -> int: # deprecated TODO remove
|
||||
return int.from_bytes(storage_thp_cache.get_next_channel_id(), "big")
|
||||
def sync_can_send_message(cache: SessionThpCache | ChannelCache) -> bool:
|
||||
return cache.sync & 0x80 == 0x80
|
||||
|
||||
|
||||
def sync_can_send_message(session: SessionThpCache) -> bool:
|
||||
return session.sync & 0x80 == 0x80
|
||||
def sync_get_receive_expected_bit(cache: SessionThpCache | ChannelCache) -> int:
|
||||
return (cache.sync & 0x40) >> 6
|
||||
|
||||
|
||||
def sync_get_receive_expected_bit(session: SessionThpCache) -> int:
|
||||
return (session.sync & 0x40) >> 6
|
||||
def sync_get_send_bit(cache: SessionThpCache | ChannelCache) -> int:
|
||||
return (cache.sync & 0x20) >> 5
|
||||
|
||||
|
||||
def sync_get_send_bit(session: SessionThpCache) -> int:
|
||||
return (session.sync & 0x20) >> 5
|
||||
|
||||
|
||||
def sync_set_can_send_message(session: SessionThpCache, can_send: bool) -> None:
|
||||
session.sync &= 0x7F
|
||||
def sync_set_can_send_message(
|
||||
cache: SessionThpCache | ChannelCache, can_send: bool
|
||||
) -> None:
|
||||
cache.sync &= 0x7F
|
||||
if can_send:
|
||||
session.sync |= 0x80
|
||||
cache.sync |= 0x80
|
||||
|
||||
|
||||
def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None:
|
||||
def sync_set_receive_expected_bit(
|
||||
cache: SessionThpCache | ChannelCache, bit: int
|
||||
) -> None:
|
||||
if bit not in (0, 1):
|
||||
raise ThpError("Unexpected receive sync bit")
|
||||
|
||||
# set second bit to "bit" value
|
||||
session.sync &= 0xBF
|
||||
cache.sync &= 0xBF
|
||||
if bit:
|
||||
session.sync |= 0x40
|
||||
cache.sync |= 0x40
|
||||
|
||||
|
||||
def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None:
|
||||
_sync_set_send_bit(session=session, bit=1 - sync_get_send_bit(session))
|
||||
def sync_set_send_bit_to_opposite(cache: SessionThpCache | ChannelCache) -> None:
|
||||
_sync_set_send_bit(cache=cache, bit=1 - sync_get_send_bit(cache))
|
||||
|
||||
|
||||
def is_active_session(session: SessionThpCache):
|
||||
@ -126,13 +126,13 @@ def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache | None:
|
||||
return None
|
||||
|
||||
|
||||
def _sync_set_send_bit(session: SessionThpCache, bit: int) -> None:
|
||||
def _sync_set_send_bit(cache: SessionThpCache | ChannelCache, bit: int) -> None:
|
||||
if bit not in (0, 1):
|
||||
raise ThpError("Unexpected send sync bit")
|
||||
|
||||
# set third bit to "bit" value
|
||||
session.sync &= 0xDF
|
||||
session.sync |= 0x20
|
||||
cache.sync &= 0xDF
|
||||
cache.sync |= 0x20
|
||||
|
||||
|
||||
def _decode_session_state(state: bytearray) -> int:
|
||||
|
@ -1,16 +1,24 @@
|
||||
import ustruct
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
import ustruct # pyright: ignore[reportMissingModuleSource]
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache
|
||||
from trezor import io, log, loop, utils
|
||||
|
||||
from .protocol_common import MessageWithId
|
||||
from .thp import ack_handler, checksum, thp_messages
|
||||
from .thp import ChannelState, ack_handler, checksum, thp_messages
|
||||
from .thp import thp_session as THP
|
||||
from .thp.channel_context import ChannelContext, load_cached_channels
|
||||
from .thp.channel_context import (
|
||||
_INIT_DATA_OFFSET,
|
||||
_MAX_PAYLOAD_LEN,
|
||||
_REPORT_CONT_DATA_OFFSET,
|
||||
_REPORT_LENGTH,
|
||||
ChannelContext,
|
||||
load_cached_channels,
|
||||
)
|
||||
from .thp.checksum import CHECKSUM_LENGTH
|
||||
from .thp.thp_messages import (
|
||||
CODEC_V1,
|
||||
CONTINUATION_PACKET,
|
||||
ENCRYPTED_TRANSPORT,
|
||||
InitHeader,
|
||||
@ -19,18 +27,13 @@ from .thp.thp_messages import (
|
||||
from .thp.thp_session import SessionState, ThpError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
|
||||
|
||||
_MAX_PAYLOAD_LEN = const(60000)
|
||||
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
|
||||
_CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_ACK_MESSAGE = 0x20
|
||||
_HANDSHAKE_INIT = 0x00
|
||||
_PLAINTEXT = 0x01
|
||||
|
||||
_REPORT_LENGTH = const(64)
|
||||
_REPORT_INIT_DATA_OFFSET = const(5)
|
||||
_REPORT_CONT_DATA_OFFSET = const(3)
|
||||
|
||||
_BUFFER: bytearray
|
||||
_BUFFER_LOCK = None
|
||||
@ -63,6 +66,12 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
packet = await read
|
||||
ctrl_byte, cid = ustruct.unpack(">BH", packet)
|
||||
|
||||
if ctrl_byte == CODEC_V1:
|
||||
pass
|
||||
# TODO add handling of (unsupported) codec_v1 packets
|
||||
# possibly ignore continuation packets, i.e. if the
|
||||
# following bytes are not "##"", do not respond
|
||||
|
||||
if cid == BROADCAST_CHANNEL_ID:
|
||||
# TODO handle exceptions, try-catch?
|
||||
await _handle_broadcast(iface, ctrl_byte, packet)
|
||||
@ -72,10 +81,10 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
|
||||
channel = _CHANNEL_CONTEXTS[cid]
|
||||
if channel is None:
|
||||
raise ThpError("Invalid state of a channel")
|
||||
# TODO if the channelContext interface is not None and is different from
|
||||
# the one used in the transmission of the packet, raise an exception
|
||||
# TODO add current wire interface to channelContext if its iface is None
|
||||
if channel.get_management_session_state != SessionState.UNALLOCATED:
|
||||
if channel.iface is not iface:
|
||||
raise ThpError("Channel has different WireInterface")
|
||||
|
||||
if channel.get_channel_state() != ChannelState.UNALLOCATED:
|
||||
await channel.receive_packet(packet)
|
||||
continue
|
||||
|
||||
@ -204,7 +213,7 @@ async def _buffer_received_data(
|
||||
payload: utils.BufferType, header: InitHeader, iface, report
|
||||
) -> None | InterruptingInitPacket:
|
||||
# buffer the initial data
|
||||
nread = utils.memcpy(payload, 0, report, _REPORT_INIT_DATA_OFFSET)
|
||||
nread = utils.memcpy(payload, 0, report, _INIT_DATA_OFFSET)
|
||||
while nread < header.length:
|
||||
# wait for continuation report
|
||||
report = await _get_loop_wait_read(iface)
|
||||
@ -297,7 +306,7 @@ async def write_to_wire(
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, _REPORT_INIT_DATA_OFFSET, payload, 0)
|
||||
nwritten = utils.memcpy(report, _INIT_DATA_OFFSET, payload, 0)
|
||||
await _write_report(loop_write, iface, report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
@ -332,13 +341,13 @@ async def _handle_broadcast(
|
||||
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
||||
raise ThpError("Checksum is not valid")
|
||||
|
||||
deprecated_channel_id = _get_new_channel_id() # TODO remove
|
||||
THP.create_new_unauthenticated_session(iface, deprecated_channel_id) # TODO remove
|
||||
new_context: ChannelContext = ChannelContext.create_new_channel(iface)
|
||||
cid = int.from_bytes(new_context.channel_id, "big")
|
||||
_CHANNEL_CONTEXTS[cid] = new_context
|
||||
|
||||
response_data = thp_messages.get_channel_allocation_response(nonce, cid)
|
||||
response_data = thp_messages.get_channel_allocation_response(
|
||||
nonce, new_context.channel_id
|
||||
)
|
||||
response_header = InitHeader.get_channel_allocation_response_header(
|
||||
len(response_data) + CHECKSUM_LENGTH,
|
||||
)
|
||||
@ -389,10 +398,6 @@ async def _handle_unexpected_sync_bit(
|
||||
# (some such messages might be handled in the classical "allocated" way, if the sync bit is right)
|
||||
|
||||
|
||||
def _get_new_channel_id() -> int:
|
||||
return THP.get_next_channel_id()
|
||||
|
||||
|
||||
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
|
||||
return ctrl_byte & 0x80 == CONTINUATION_PACKET
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user