Channel refactor

M1nd3r/thp5
M1nd3r 1 month ago
parent 531176570d
commit ea6cb32e59

@ -213,16 +213,30 @@ trezor.wire.thp.ack_handler
import trezor.wire.thp.ack_handler import trezor.wire.thp.ack_handler
trezor.wire.thp.channel trezor.wire.thp.channel
import trezor.wire.thp.channel import trezor.wire.thp.channel
trezor.wire.thp.channel_manager
import trezor.wire.thp.channel_manager
trezor.wire.thp.checksum trezor.wire.thp.checksum
import trezor.wire.thp.checksum import trezor.wire.thp.checksum
trezor.wire.thp.control_byte
import trezor.wire.thp.control_byte
trezor.wire.thp.crypto trezor.wire.thp.crypto
import trezor.wire.thp.crypto import trezor.wire.thp.crypto
trezor.wire.thp.handler_provider trezor.wire.thp.handler_provider
import trezor.wire.thp.handler_provider import trezor.wire.thp.handler_provider
trezor.wire.thp.interface_manager
import trezor.wire.thp.interface_manager
trezor.wire.thp.memory_manager
import trezor.wire.thp.memory_manager
trezor.wire.thp.pairing_context trezor.wire.thp.pairing_context
import trezor.wire.thp.pairing_context import trezor.wire.thp.pairing_context
trezor.wire.thp.received_message_handler
import trezor.wire.thp.received_message_handler
trezor.wire.thp.retransmission
import trezor.wire.thp.retransmission
trezor.wire.thp.session_context trezor.wire.thp.session_context
import trezor.wire.thp.session_context import trezor.wire.thp.session_context
trezor.wire.thp.session_manager
import trezor.wire.thp.session_manager
trezor.wire.thp.thp_messages trezor.wire.thp.thp_messages
import trezor.wire.thp.thp_messages import trezor.wire.thp.thp_messages
trezor.wire.thp.thp_session trezor.wire.thp.thp_session

@ -1,18 +1,20 @@
from trezor import log, loop from trezor import log, loop
from trezor.messages import ThpCreateNewSession, ThpNewSession from trezor.messages import ThpCreateNewSession, ThpNewSession
from trezor.wire.thp import SessionState, channel from trezor.wire.thp import ChannelContext, SessionState
async def create_new_session( async def create_new_session(
channel: channel.Channel, message: ThpCreateNewSession channel: ChannelContext, message: ThpCreateNewSession
) -> ThpNewSession: ) -> ThpNewSession:
from trezor.wire.thp.session_context import SessionContext # from apps.common.seed import get_seed TODO
from trezor.wire.thp.session_manager import create_new_session
session = SessionContext.create_new_session(channel) session = create_new_session(channel)
session.set_session_state(SessionState.ALLOCATED) session.set_session_state(SessionState.ALLOCATED)
channel.sessions[session.session_id] = session channel.sessions[session.session_id] = session
loop.schedule(session.handle()) loop.schedule(session.handle())
new_session_id: int = session.session_id new_session_id: int = session.session_id
# await get_seed() TODO
if __debug__: if __debug__:
log.debug( log.debug(

@ -37,12 +37,12 @@ async def handle_pairing_request(
_check_state(ctx, ChannelState.TP1) _check_state(ctx, ChannelState.TP1)
if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry): if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry):
ctx.channel.set_channel_state(ChannelState.TP2) ctx.channel_ctx.set_channel_state(ChannelState.TP2)
response = await ctx.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge) response = await ctx.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge)
return await _handle_code_entry_challenge(ctx, response) return await _handle_code_entry_challenge(ctx, response)
ctx.channel.set_channel_state(ChannelState.TP3) ctx.channel_ctx.set_channel_state(ChannelState.TP3)
response = await ctx.call_any( response = await ctx.call_any(
ThpPairingPreparationsFinished(), ThpPairingPreparationsFinished(),
MessageType.ThpQrCodeTag, MessageType.ThpQrCodeTag,
@ -63,7 +63,7 @@ async def _handle_code_entry_challenge(
assert ThpCodeEntryChallenge.is_type_of(message) assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(ctx, ChannelState.TP2) _check_state(ctx, ChannelState.TP2)
ctx.channel.set_channel_state(ChannelState.TP3) ctx.channel_ctx.set_channel_state(ChannelState.TP3)
response = await ctx.call_any( response = await ctx.call_any(
ThpPairingPreparationsFinished(), ThpPairingPreparationsFinished(),
MessageType.ThpCodeEntryCpaceHost, MessageType.ThpCodeEntryCpaceHost,
@ -88,7 +88,7 @@ async def _handle_code_entry_cpace(
_check_state(ctx, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry)
ctx.channel.set_channel_state(ChannelState.TP4) ctx.channel_ctx.set_channel_state(ChannelState.TP4)
response = await ctx.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) response = await ctx.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag)
return await _handle_code_entry_tag(ctx, response) return await _handle_code_entry_tag(ctx, response)
@ -149,7 +149,7 @@ async def _handle_end_request(
assert ThpEndRequest.is_type_of(message) assert ThpEndRequest.is_type_of(message)
_check_state(ctx, ChannelState.TC1) _check_state(ctx, ChannelState.TC1)
ctx.channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse() return ThpEndResponse()
@ -161,7 +161,7 @@ async def _handle_tag_message(
) -> ThpEndResponse: ) -> ThpEndResponse:
_check_state(ctx, expected_state) _check_state(ctx, expected_state)
_check_method_is_allowed(ctx, used_method) _check_method_is_allowed(ctx, used_method)
ctx.channel.set_channel_state(ChannelState.TC1) ctx.channel_ctx.set_channel_state(ChannelState.TC1)
response = await ctx.call_any( response = await ctx.call_any(
msg, msg,
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
@ -171,7 +171,7 @@ async def _handle_tag_message(
def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None: def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None:
if expected_state is not ctx.channel.get_channel_state(): if expected_state is not ctx.channel_ctx.get_channel_state():
raise UnexpectedMessage("Unexpected message") raise UnexpectedMessage("Unexpected message")
@ -181,7 +181,7 @@ def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> N
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel.selected_pairing_methods return method in ctx.channel_ctx.selected_pairing_methods
async def _handle_credential_request_or_end_request( async def _handle_credential_request_or_end_request(

@ -2,6 +2,13 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
if TYPE_CHECKING: if TYPE_CHECKING:
from enum import IntEnum from enum import IntEnum
from trezorio import WireInterface
from storage.cache_thp import ChannelCache
from trezor import loop, protobuf, utils
from trezor.enums import FailureType
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.session_context import SessionContext
else: else:
IntEnum = object IntEnum = object
@ -27,3 +34,49 @@ class WireInterfaceType(IntEnum):
MOCK = 0 MOCK = 0
USB = 1 USB = 1
BLE = 2 BLE = 2
class ChannelContext:
def __init__(self, iface: WireInterface, channel_cache: ChannelCache):
self.buffer: utils.BufferType
self.iface: WireInterface = iface
self.channel_id: bytes = channel_cache.channel_id
self.channel_cache: ChannelCache = channel_cache
self.selected_pairing_methods = []
self.sessions: dict[int, SessionContext] = {}
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
def get_channel_state(self) -> int: ...
def set_channel_state(self, state: ChannelState) -> None: ...
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: ...
async def write_error(self, err_type: FailureType, message: str) -> None: ...
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: ...
def decrypt_buffer(self, message_length: int) -> None: ...
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,
ChannelState.TP2,
ChannelState.TP3,
ChannelState.TP4,
ChannelState.TC1,
):
return True
return False
if __debug__:
def state_to_str(state: int) -> str:
name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
}.get(state)
if name is not None:
return name
return "UNKNOWN_STATE"

@ -1,101 +1,60 @@
import ustruct # pyright: ignore[reportMissingModuleSource] import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
import usb from storage.cache_thp import TAG_LENGTH, ChannelCache
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from trezor import log, loop, protobuf, utils, workflow from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType, MessageType from trezor.enums import FailureType
from trezor.messages import ( from trezor.wire.thp import interface_manager, received_message_handler
Failure,
ThpCreateNewSession, from . import (
ThpHandshakeCompletionReqNoisePayload, ChannelContext,
ChannelState,
checksum,
control_byte,
crypto,
memory_manager,
) )
from trezor.wire import message_handler
from trezor.wire.thp import ack_handler, thp_messages
from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum, crypto
from . import thp_session as THP from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH from .checksum import CHECKSUM_LENGTH
from .crypto import PUBKEY_LENGTH from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader
from .thp_messages import (
ACK_MESSAGE,
CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT,
ERROR,
HANDSHAKE_COMP_REQ,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_REQ,
HANDSHAKE_INIT_RES,
InitHeader,
)
from .thp_session import ThpError from .thp_session import ThpError
from .writer import ( from .writer import (
CONT_DATA_OFFSET, CONT_DATA_OFFSET,
INIT_DATA_OFFSET, INIT_DATA_OFFSET,
REPORT_LENGTH, MESSAGE_TYPE_LENGTH,
write_payload_to_wire, write_payload_to_wire,
) )
if TYPE_CHECKING: if __debug__:
from trezorio import WireInterface # pyright:ignore[reportMissingImports] from . import state_to_str
_WIRE_INTERFACE_USB = b"\x01"
_MOCK_INTERFACE_HID = b"\x00"
if TYPE_CHECKING:
MESSAGE_TYPE_LENGTH = const(2) from trezorio import WireInterface # pyright: ignore[reportMissingImports]
MAX_PAYLOAD_LEN = const(60000)
class Channel(Context): class Channel(ChannelContext):
def __init__(self, channel_cache: ChannelCache) -> None: def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__: if __debug__:
log.debug(__name__, "channel initialization") log.debug(__name__, "channel initialization")
iface = _decode_iface(channel_cache.iface) iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache.channel_id) super().__init__(iface, channel_cache)
self.channel_cache = channel_cache self.channel_cache = channel_cache
self.buffer: utils.BufferType
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.is_cont_packet_expected: bool = False self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0 self.expected_payload_length: int = 0
self.bytes_read: int = 0 self.bytes_read: int = 0
self.selected_pairing_methods = []
from trezor.wire.thp.session_context import load_cached_sessions
self.connection_context = None
self.sessions = load_cached_sessions(self)
@classmethod
def create_new_channel(
cls, iface: WireInterface, buffer: utils.BufferType
) -> "Channel":
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
r = cls(channel_cache)
r.set_buffer(buffer)
r.set_channel_state(ChannelState.TH1)
return r
# ACCESS TO CHANNEL_DATA # ACCESS TO CHANNEL_DATA
def get_channel_state(self) -> int: def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big") state = int.from_bytes(self.channel_cache.state, "big")
if __debug__: if __debug__:
log.debug(__name__, "get_channel_state: %s", _state_to_str(state)) log.debug(__name__, "get_channel_state: %s", state_to_str(state))
return state return state
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def set_channel_state(self, state: ChannelState) -> None: def set_channel_state(self, state: ChannelState) -> None:
if __debug__:
log.debug(__name__, "set_channel_state: %s", _state_to_str(state))
self.channel_cache.state = bytearray(state.to_bytes(1, "big")) self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
if __debug__:
log.debug(__name__, "set_channel_state: %s", state_to_str(state))
def set_buffer(self, buffer: utils.BufferType) -> None: def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer self.buffer = buffer
@ -115,7 +74,7 @@ class Channel(Context):
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
self._finish_message() self._finish_message()
await self._handle_completed_message() await received_message_handler.handle_received_message(self, self.buffer)
elif self.expected_payload_length + INIT_DATA_OFFSET > self.bytes_read: elif self.expected_payload_length + INIT_DATA_OFFSET > self.bytes_read:
self.is_cont_packet_expected = True self.is_cont_packet_expected = True
else: else:
@ -125,7 +84,7 @@ class Channel(Context):
async def _handle_received_packet(self, packet: utils.BufferType) -> None: async def _handle_received_packet(self, packet: utils.BufferType) -> None:
ctrl_byte = packet[0] ctrl_byte = packet[0]
if _is_ctrl_byte_continuation(ctrl_byte): if control_byte.is_continuation(ctrl_byte):
await self._handle_cont_packet(packet) await self._handle_cont_packet(packet)
else: else:
await self._handle_init_packet(packet) await self._handle_init_packet(packet)
@ -138,42 +97,21 @@ class Channel(Context):
packet_payload = packet[5:] packet_payload = packet[5:]
# If the channel does not "own" the buffer lock, decrypt first packet # If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed! # TODO do it only when needed!
if _is_ctrl_byte_encrypted_transport(ctrl_byte): if control_byte.is_encrypted_transport(ctrl_byte):
packet_payload = self._decrypt_single_packet_payload(packet_payload) packet_payload = self._decrypt_single_packet_payload(packet_payload)
self._select_buffer(packet_payload, payload_length) self.buffer = memory_manager.select_buffer(
self.get_channel_state(),
self.buffer,
packet_payload,
payload_length,
)
await self._buffer_packet_data(self.buffer, packet, 0) await self._buffer_packet_data(self.buffer, packet, 0)
if __debug__: if __debug__:
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length) log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer)) log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
def _select_buffer(
self, packet_payload: utils.BufferType, payload_length: int
) -> None:
state = self.get_channel_state()
if state is ChannelState.ENCRYPTED_TRANSPORT:
session_id = packet_payload[0]
if session_id == 0:
pass
# TODO use small buffer
else:
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return
else:
pass
# TODO use small buffer
try:
# TODO for now, we create a new big buffer every time. It should be changed
self.buffer: utils.BufferType = _get_buffer_for_message(
payload_length, self.buffer
)
except Exception as e:
if __debug__:
log.exception(__name__, e)
async def _handle_cont_packet(self, packet: utils.BufferType) -> None: async def _handle_cont_packet(self, packet: utils.BufferType) -> None:
if __debug__: if __debug__:
log.debug(__name__, "handle_cont_packet") log.debug(__name__, "handle_cont_packet")
@ -181,299 +119,12 @@ class Channel(Context):
raise ThpError("Continuation packet is not expected, ignoring") raise ThpError("Continuation packet is not expected, ignoring")
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET) await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
async def _handle_completed_message(self) -> None:
if __debug__:
log.debug(__name__, "handle_completed_message")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
message_length = payload_length + INIT_DATA_OFFSET
self._check_checksum(message_length)
# Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4
if __debug__:
log.debug(
__name__,
"handle_completed_message - sync bit of message: %d",
sync_bit,
)
# 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte):
await self._handle_ack(sync_bit)
return
if (
self._should_have_ctrl_byte_encrypted_transport()
and not _is_ctrl_byte_encrypted_transport(ctrl_byte)
):
self._todo_clear_buffer()
raise ThpError("Message is not encrypted. Ignoring")
# 2: Handle message with unexpected synchronization bit
if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache):
if __debug__:
log.debug(
__name__, "Received message with an unexpected synchronization bit"
)
await self._send_ack(sync_bit)
raise ThpError("Received message with an unexpected synchronization bit")
# 3: Send ACK in response
await self._send_ack(sync_bit)
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
await self._handle_message_to_app_or_channel(
payload_length, message_length, ctrl_byte, sync_bit
)
if __debug__:
log.debug(__name__, "handle_completed_message - end")
async def _handle_ack(self, sync_bit: int):
if not ack_handler.is_ack_valid(self.channel_cache, sync_bit):
return
# ACK is expected and it has correct sync bit
if __debug__:
log.debug(__name__, "Received ACK message with correct sync bit")
if self.waiting_for_ack_timeout is not None:
self.waiting_for_ack_timeout.close()
if __debug__:
log.debug(__name__, 'Closed "waiting for ack" task')
THP.sync_set_can_send_message(self.channel_cache, True)
if self.write_task_spawn is not None:
if __debug__:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await self.write_task_spawn
# Note that no the write_task_spawn could result in loop.clear(),
# which will result in terminations of this function - any code after
# this await might not be executed
def _check_checksum(self, message_length: int):
if __debug__:
log.debug(__name__, "check_checksum")
if not checksum.is_valid(
checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length],
data=self.buffer[: message_length - CHECKSUM_LENGTH],
):
self._todo_clear_buffer()
if __debug__:
log.debug(__name__, "Invalid checksum, ignoring message.")
raise ThpError("Invalid checksum, ignoring message.")
async def _handle_message_to_app_or_channel(
self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
state = self.get_channel_state()
if __debug__:
log.debug(__name__, "state: %s", _state_to_str(state))
if state is ChannelState.ENCRYPTED_TRANSPORT:
await self._handle_state_ENCRYPTED_TRANSPORT(message_length)
return
if state is ChannelState.TH1:
await self._handle_state_TH1(
payload_length, message_length, ctrl_byte, sync_bit
)
return
if state is ChannelState.TH2:
await self._handle_state_TH2(message_length, ctrl_byte, sync_bit)
return
if is_channel_state_pairing(state):
await self._handle_pairing(message_length)
return
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH1")
if not _is_ctrl_byte_handshake_init_req(ctrl_byte):
raise ThpError("Message received is not a handshake init request!")
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
host_ephemeral_key = bytearray(
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
# send handshake init response message
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response()
)
)
self.set_channel_state(ChannelState.TH2)
return
async def _handle_state_TH2(
self, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH2")
if not _is_ctrl_byte_handshake_comp_req(ctrl_byte):
raise ThpError("Message received is not a handshake completion request!")
host_encrypted_static_pubkey = self.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = self.buffer[
INIT_DATA_OFFSET
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
]
noise_payload = thp_messages.decode_message(
self.buffer[
INIT_DATA_OFFSET
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
0,
"ThpHandshakeCompletionReqNoisePayload",
)
if TYPE_CHECKING:
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
for i in noise_payload.pairing_methods:
self.selected_pairing_methods.append(i)
if __debug__:
log.debug(
__name__,
"host static pubkey: %s, noise payload: %s",
utils.get_bytes_as_str(host_encrypted_static_pubkey),
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
# TODO add credential recognition
paired: bool = False # TODO should be output from credential check
# send hanshake completion response
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
HANDSHAKE_COMP_RES,
thp_messages.get_handshake_completion_response(paired=paired),
)
)
if paired:
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
self.set_channel_state(ChannelState.TP1)
async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
if __debug__:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
self._decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
await self._handle_channel_message(message_length, message_type)
return
if session_id not in self.sessions:
await self.write_error(
FailureType.ThpUnallocatedSession, "Unallocated session"
)
raise ThpError("Unalloacted session")
session_state = self.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED:
await self.write_error(
FailureType.ThpUnallocatedSession, "Unallocated session"
)
raise ThpError("Unalloacted session")
self.sessions[session_id].incoming_message.publish(
MessageWithType(
message_type,
self.buffer[
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
async def _handle_pairing(self, message_length: int) -> None:
from .pairing_context import PairingContext
if self.connection_context is None:
self.connection_context = PairingContext(self)
loop.schedule(self.connection_context.handle())
self._decrypt_buffer(message_length)
message_type = ustruct.unpack(
">H", self.buffer[INIT_DATA_OFFSET + SESSION_ID_LENGTH :]
)[0]
self.connection_context.incoming_message.publish(
MessageWithType(
message_type,
self.buffer[
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
# 1. Check that message is expected with respect to the current state
# 2. Handle the message
pass
def _should_have_ctrl_byte_encrypted_transport(self) -> bool:
if self.get_channel_state() in [
ChannelState.UNALLOCATED,
ChannelState.TH1,
ChannelState.TH2,
]:
return False
return True
async def _handle_channel_message(
self, message_length: int, message_type: int
) -> None:
buf = self.buffer[
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
if not ThpCreateNewSession.is_type_of(message):
raise ThpError(
"This message cannot be handled by channel itself. It must be send to allocated session."
)
# TODO handle other messages than CreateNewSession
from trezor.wire.thp.handler_provider import get_handler_for_channel_message
handler = get_handler_for_channel_message(message)
task = handler(self, message)
response_message = await task
# TODO handle
await self.write(response_message)
if __debug__:
log.debug(__name__, "_handle_channel_message - end")
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray: def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
payload_buffer = bytearray(payload) payload_buffer = bytearray(payload)
crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload)) crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
return payload_buffer return payload_buffer
def _decrypt_buffer(self, message_length: int) -> None: def decrypt_buffer(self, message_length: int) -> None:
if not isinstance(self.buffer, bytearray): if not isinstance(self.buffer, bytearray):
self.buffer = bytearray(self.buffer) self.buffer = bytearray(self.buffer)
crypto.decrypt( crypto.decrypt(
@ -511,38 +162,22 @@ class Channel(Context):
self.expected_payload_length = 0 self.expected_payload_length = 0
self.is_cont_packet_expected = False self.is_cont_packet_expected = False
async def _send_ack(self, ack_bit: int) -> None:
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = InitHeader(ctrl_byte, self.get_channel_id_int(), CHECKSUM_LENGTH)
chksum = checksum.compute(header.to_bytes())
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
self.get_channel_id_int(),
ack_bit,
)
await write_payload_to_wire(self.iface, header, chksum)
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
if sync_bit == 0:
return ctrl_byte & 0xEF
if sync_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit")
# CALLED BY WORKFLOW / SESSION CONTEXT # CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
if __debug__: if __debug__:
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME) log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
noise_payload_len = self._encode_into_buffer(msg, session_id) noise_payload_len = memory_manager.encode_into_buffer(
memoryview(self.buffer), msg, session_id
)
await self.write_and_encrypt(self.buffer[:noise_payload_len]) await self.write_and_encrypt(self.buffer[:noise_payload_len])
async def write_error(self, err_type: FailureType, message: str) -> None: async def write_error(self, err_type: FailureType, message: str) -> None:
if __debug__: if __debug__:
log.debug(__name__, "write_error") log.debug(__name__, "write_error")
msg_size = self._encode_error_into_buffer(err_type, message) msg_size = memory_manager.encode_error_into_buffer(
memoryview(self.buffer), err_type, message
)
data_length = MESSAGE_TYPE_LENGTH + msg_size data_length = MESSAGE_TYPE_LENGTH + msg_size
header: InitHeader = InitHeader( header: InitHeader = InitHeader(
ERROR, self.get_channel_id_int(), data_length + CHECKSUM_LENGTH ERROR, self.get_channel_id_int(), data_length + CHECKSUM_LENGTH
@ -574,6 +209,12 @@ class Channel(Context):
) )
) )
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(ctrl_byte, payload)
)
def _prepare_write(self) -> None: def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false # TODO add condition that disallows to write when can_send_message is false
THP.sync_set_can_send_message(self.channel_cache, False) THP.sync_set_can_send_message(self.channel_cache, False)
@ -585,7 +226,7 @@ class Channel(Context):
log.debug(__name__, "write_encrypted_payload_loop") log.debug(__name__, "write_encrypted_payload_loop")
payload_len = len(payload) + CHECKSUM_LENGTH payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = THP.sync_get_send_bit(self.channel_cache) sync_bit = THP.sync_get_send_bit(self.channel_cache)
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit) ctrl_byte = control_byte.add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit)
header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len) header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
chksum = checksum.compute(header.to_bytes() + payload) chksum = checksum.compute(header.to_bytes() + payload)
payload = payload + chksum payload = payload + chksum
@ -627,160 +268,3 @@ class Channel(Context):
async def _wait_for_ack(self) -> None: async def _wait_for_ack(self) -> None:
await loop.sleep(1000) await loop.sleep(1000)
def _encode_into_buffer(self, msg: protobuf.MessageType, session_id: int) -> int:
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(self.buffer):
# message is too big, we need to allocate a new buffer
self.buffer = bytearray(required_min_size)
buffer = self.buffer
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(
memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
)
_encode_message_into_buffer(
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
return payload_size
def _encode_error_into_buffer(self, err_code: FailureType, message: str) -> int:
error_message: protobuf.MessageType = Failure(code=err_code, message=message)
_encode_message_type_into_buffer(memoryview(self.buffer), MessageType.Failure)
_encode_message_into_buffer(
memoryview(self.buffer), error_message, MESSAGE_TYPE_LENGTH
)
return protobuf.encoded_length(error_message)
def _todo_clear_buffer(self):
# TODO Buffer clearing not implemented
pass
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values():
c.set_buffer(buffer)
return channels
def _decode_iface(cached_iface: bytes) -> WireInterface:
if cached_iface == _WIRE_INTERFACE_USB:
iface = usb.iface_wire
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
if __debug__ and cached_iface == _MOCK_INTERFACE_HID:
raise NotImplementedError("Should return MockHID WireInterface")
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def _encode_iface(iface: WireInterface) -> bytes:
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
if __debug__:
return _MOCK_INTERFACE_HID
raise Exception("Unknown WireInterface")
def _get_buffer_for_message(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
length = payload_length + INIT_DATA_OFFSET
if __debug__:
log.debug(
__name__,
"get_buffer_for_message - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
# allocate a new buffer to fit the message
try:
payload: utils.BufferType = bytearray(length)
except MemoryError:
payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:length]
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & 0x80 == CONTINUATION_PACKET
def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def _is_ctrl_byte_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ
def _is_ctrl_byte_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE
def is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,
ChannelState.TP2,
ChannelState.TP3,
ChannelState.TP4,
ChannelState.TC1,
):
return True
return False
def _encode_session_into_buffer(
buffer: memoryview, session_id: int, buffer_offset: int = 0
) -> None:
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
def _encode_message_type_into_buffer(
buffer: memoryview, message_type: int, offset: int = 0
) -> None:
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, offset, msg_type_bytes, 0)
def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _state_to_str(state: int) -> str:
name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
}.get(state)
if name is not None:
return name
return "UNKNOWN_STATE"

@ -0,0 +1,30 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import utils
from . import ChannelState, interface_manager
from .channel import Channel
if TYPE_CHECKING:
from trezorio import WireInterface # pyright:ignore[reportMissingImports]
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> "Channel":
channel_cache = cache_thp.get_new_unauthenticated_channel(
interface_manager.encode_iface(iface)
)
r = Channel(channel_cache)
r.set_buffer(buffer)
r.set_channel_state(ChannelState.TH1)
return r
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values():
c.set_buffer(buffer)
return channels

@ -0,0 +1,36 @@
from trezor.wire.thp.thp_messages import (
ACK_MESSAGE,
CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT,
HANDSHAKE_COMP_REQ,
HANDSHAKE_INIT_REQ,
)
from trezor.wire.thp.thp_session import ThpError
def add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
if sync_bit == 0:
return ctrl_byte & 0xEF
if sync_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit")
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & 0x80 == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ

@ -0,0 +1,32 @@
from typing import TYPE_CHECKING
import usb
_MOCK_INTERFACE_HID = b"\x00"
_WIRE_INTERFACE_USB = b"\x01"
if TYPE_CHECKING:
from trezorio import WireInterface # pyright:ignore[reportMissingImports]
def decode_iface(cached_iface: bytes) -> WireInterface:
"""Decode the cached wire interface."""
if cached_iface == _WIRE_INTERFACE_USB:
iface = usb.iface_wire
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
if __debug__ and cached_iface == _MOCK_INTERFACE_HID:
raise NotImplementedError("Should return MockHID WireInterface")
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def encode_iface(iface: WireInterface) -> bytes:
"""Encode wire interface into bytes."""
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
if __debug__:
return _MOCK_INTERFACE_HID
raise Exception("Unknown WireInterface")

@ -0,0 +1,128 @@
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, protobuf, utils
from trezor.enums import FailureType, MessageType
from trezor.messages import Failure
from . import ChannelState
from .checksum import CHECKSUM_LENGTH
from .thp_session import ThpError
from .writer import (
INIT_DATA_OFFSET,
MAX_PAYLOAD_LEN,
MESSAGE_TYPE_LENGTH,
REPORT_LENGTH,
)
def select_buffer(
channel_state: int,
channel_buffer: utils.BufferType,
packet_payload: utils.BufferType,
payload_length: int,
) -> utils.BufferType:
if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
session_id = packet_payload[0]
if session_id == 0:
pass
# TODO use small buffer
else:
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return
else:
pass
# TODO use small buffer
try:
# TODO for now, we create a new big buffer every time. It should be changed
buffer: utils.BufferType = _get_buffer_for_message(
payload_length, channel_buffer
)
return buffer
except Exception as e:
if __debug__:
log.exception(__name__, e)
raise Exception("Failed to create a buffer for channel") # TODO handle better
def encode_into_buffer(
buffer: memoryview, msg: protobuf.MessageType, session_id: int
) -> int:
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(buffer):
# message is too big, we need to allocate a new buffer
buffer = memoryview(bytearray(required_min_size))
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(
memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
)
_encode_message_into_buffer(
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
return payload_size
def encode_error_into_buffer(
buffer: memoryview, err_code: FailureType, message: str
) -> int:
error_message: protobuf.MessageType = Failure(code=err_code, message=message)
_encode_message_type_into_buffer(buffer, MessageType.Failure)
_encode_message_into_buffer(buffer, error_message, MESSAGE_TYPE_LENGTH)
return protobuf.encoded_length(error_message)
def _encode_session_into_buffer(
buffer: memoryview, session_id: int, buffer_offset: int = 0
) -> None:
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
def _encode_message_type_into_buffer(
buffer: memoryview, message_type: int, offset: int = 0
) -> None:
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, offset, msg_type_bytes, 0)
def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _get_buffer_for_message(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
length = payload_length + INIT_DATA_OFFSET
if __debug__:
log.debug(
__name__,
"get_buffer_for_message - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
# allocate a new buffer to fit the message
try:
payload: utils.BufferType = bytearray(length)
except MemoryError:
payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:length]

@ -5,9 +5,9 @@ from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageWithId from trezor.wire.context import UnexpectedMessageWithId
from trezor.wire.errors import ActionCancelled from trezor.wire.errors import ActionCancelled
from trezor.wire.protocol_common import Context, MessageWithType from trezor.wire.protocol_common import Context, MessageWithType
from trezor.wire.thp.session_context import UnexpectedMessageWithType
from .channel import Channel from . import ChannelContext
from .session_context import UnexpectedMessageWithType
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Container # pyright:ignore[reportShadowedImports] from typing import Container # pyright:ignore[reportShadowedImports]
@ -16,9 +16,9 @@ if TYPE_CHECKING:
class PairingContext(Context): class PairingContext(Context):
def __init__(self, channel: Channel) -> None: def __init__(self, channel_ctx: ChannelContext) -> None:
super().__init__(channel.iface, channel.channel_id) super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel = channel self.channel_ctx = channel_ctx
self.incoming_message = loop.chan() self.incoming_message = loop.chan()
async def handle(self, is_debug_session: bool = False) -> None: async def handle(self, is_debug_session: bool = False) -> None:
@ -104,7 +104,7 @@ class PairingContext(Context):
return message_handler.wrap_protobuf_load(message.data, expected_type) return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None: async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg) return await self.channel_ctx.write(msg)
async def call( async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType] self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
@ -125,7 +125,9 @@ class PairingContext(Context):
async def handle_pairing_request_message( async def handle_pairing_request_message(
ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool pairing_ctx: PairingContext,
msg: protocol_common.MessageWithType,
use_workflow: bool,
) -> protocol_common.MessageWithType | None: ) -> protocol_common.MessageWithType | None:
res_msg: protobuf.MessageType | None = None res_msg: protobuf.MessageType | None = None
@ -147,7 +149,7 @@ async def handle_pairing_request_message(
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task. # Create the handler task.
task = handle_pairing_request(ctx, req_msg) task = handle_pairing_request(pairing_ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire # Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a # communication inside, but it should eventually return a
@ -156,7 +158,7 @@ async def handle_pairing_request_message(
if use_workflow: if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent # Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down. # workflows are shut down.
res_msg = await workflow.spawn(context.with_context(ctx, task)) res_msg = await workflow.spawn(context.with_context(pairing_ctx, task))
pass # TODO pass # TODO
else: else:
# For debug messages, ignore workflow processing and just await # For debug messages, ignore workflow processing and just await
@ -193,5 +195,5 @@ async def handle_pairing_request_message(
if res_msg is not None: if res_msg is not None:
# perform the write outside the big try-except block, so that usb write # perform the write outside the big try-except block, so that usb write
# problem bubbles up # problem bubbles up
await ctx.write(res_msg) await pairing_ctx.write(res_msg)
return None return None

@ -0,0 +1,349 @@
import ustruct # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType
from trezor.messages import ThpCreateNewSession
from trezor.wire import message_handler
from trezor.wire.protocol_common import MessageWithType
from trezor.wire.thp import ack_handler, thp_messages
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
from trezor.wire.thp.crypto import PUBKEY_LENGTH
from trezor.wire.thp.thp_messages import (
ACK_MESSAGE,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_RES,
InitHeader,
)
from . import (
ChannelContext,
ChannelState,
SessionState,
checksum,
control_byte,
is_channel_state_pairing,
)
from . import thp_session as THP
from .thp_session import ThpError
from .writer import INIT_DATA_OFFSET, MESSAGE_TYPE_LENGTH, write_payload_to_wire
if TYPE_CHECKING:
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
if __debug__:
from . import state_to_str
async def handle_received_message(
ctx: ChannelContext, message_buffer: utils.BufferType
) -> None:
"""Handle a message received from the channel."""
if __debug__:
log.debug(__name__, "handle_received_message")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
message_length = payload_length + INIT_DATA_OFFSET
_check_checksum(message_length, message_buffer)
# Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4
if __debug__:
log.debug(
__name__,
"handle_completed_message - sync bit of message: %d",
sync_bit,
)
# 1: Handle ACKs
if control_byte.is_ack(ctrl_byte):
await _handle_ack(ctx, sync_bit)
return
if _should_have_ctrl_byte_encrypted_transport(
ctx
) and not control_byte.is_encrypted_transport(ctrl_byte):
raise ThpError("Message is not encrypted. Ignoring")
# 2: Handle message with unexpected synchronization bit
if sync_bit != THP.sync_get_receive_expected_bit(ctx.channel_cache):
if __debug__:
log.debug(
__name__, "Received message with an unexpected synchronization bit"
)
await _send_ack(ctx, sync_bit)
raise ThpError("Received message with an unexpected synchronization bit")
# 3: Send ACK in response
await _send_ack(ctx, sync_bit)
THP.sync_set_receive_expected_bit(ctx.channel_cache, 1 - sync_bit)
await _handle_message_to_app_or_channel(
ctx, payload_length, message_length, ctrl_byte, sync_bit
)
if __debug__:
log.debug(__name__, "handle_received_message - end")
async def _send_ack(ctx: ChannelContext, ack_bit: int) -> None:
ctrl_byte = control_byte.add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = InitHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
chksum = checksum.compute(header.to_bytes())
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
ctx.get_channel_id_int(),
ack_bit,
)
await write_payload_to_wire(ctx.iface, header, chksum)
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
if __debug__:
log.debug(__name__, "check_checksum")
if not checksum.is_valid(
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
data=message_buffer[: message_length - CHECKSUM_LENGTH],
):
if __debug__:
log.debug(__name__, "Invalid checksum, ignoring message.")
raise ThpError("Invalid checksum, ignoring message.")
# TEST THIS
async def _handle_ack(ctx: ChannelContext, sync_bit: int):
if not ack_handler.is_ack_valid(ctx.channel_cache, sync_bit):
return
# ACK is expected and it has correct sync bit
if __debug__:
log.debug(__name__, "Received ACK message with correct sync bit")
if ctx.waiting_for_ack_timeout is not None:
ctx.waiting_for_ack_timeout.close()
if __debug__:
log.debug(__name__, 'Closed "waiting for ack" task')
THP.sync_set_can_send_message(ctx.channel_cache, True)
if ctx.write_task_spawn is not None:
if __debug__:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await ctx.write_task_spawn
# Note that no the write_task_spawn could result in loop.clear(),
# which will result in terminations of this function - any code after
# this await might not be executed
async def _handle_message_to_app_or_channel(
ctx: ChannelContext,
payload_length: int,
message_length: int,
ctrl_byte: int,
sync_bit: int,
) -> None:
state = ctx.get_channel_state()
if __debug__:
log.debug(__name__, "state: %s", state_to_str(state))
if state is ChannelState.ENCRYPTED_TRANSPORT:
await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
return
if state is ChannelState.TH1:
await _handle_state_TH1(
ctx, payload_length, message_length, ctrl_byte, sync_bit
)
return
if state is ChannelState.TH2:
await _handle_state_TH2(ctx, message_length, ctrl_byte, sync_bit)
return
if is_channel_state_pairing(state):
await _handle_pairing(ctx, message_length)
return
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
ctx: ChannelContext,
payload_length: int,
message_length: int,
ctrl_byte: int,
sync_bit: int,
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH1")
if not control_byte.is_handshake_init_req(ctrl_byte):
raise ThpError("Message received is not a handshake init request!")
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
host_ephemeral_key = bytearray(
ctx.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(ctx.channel_cache, host_ephemeral_key)
# send handshake init response message
await ctx.write_handshake_message(
HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response()
)
ctx.set_channel_state(ChannelState.TH2)
return
async def _handle_state_TH2(
ctx: ChannelContext, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH2")
if not control_byte.is_handshake_comp_req(ctrl_byte):
raise ThpError("Message received is not a handshake completion request!")
host_encrypted_static_pubkey = ctx.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = ctx.buffer[
INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
]
noise_payload = thp_messages.decode_message(
ctx.buffer[
INIT_DATA_OFFSET
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
0,
"ThpHandshakeCompletionReqNoisePayload",
)
if TYPE_CHECKING:
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
for i in noise_payload.pairing_methods:
ctx.selected_pairing_methods.append(i)
if __debug__:
log.debug(
__name__,
"host static pubkey: %s, noise payload: %s",
utils.get_bytes_as_str(host_encrypted_static_pubkey),
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
# TODO add credential recognition
paired: bool = True # TODO should be output from credential check
# send hanshake completion response
await ctx.write_handshake_message(
HANDSHAKE_COMP_RES,
thp_messages.get_handshake_completion_response(paired),
)
if paired:
ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
ctx.set_channel_state(ChannelState.TP1)
async def _handle_state_ENCRYPTED_TRANSPORT(
ctx: ChannelContext, message_length: int
) -> None:
if __debug__:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
ctx.decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", ctx.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
await _handle_channel_message(ctx, message_length, message_type)
return
if session_id not in ctx.sessions:
await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session")
raise ThpError("Unalloacted session")
session_state = ctx.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED:
await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session")
raise ThpError("Unalloacted session")
ctx.sessions[session_id].incoming_message.publish(
MessageWithType(
message_type,
ctx.buffer[
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
async def _handle_pairing(ctx: ChannelContext, message_length: int) -> None:
from .pairing_context import PairingContext
if ctx.connection_context is None:
ctx.connection_context = PairingContext(ctx)
loop.schedule(ctx.connection_context.handle())
ctx.decrypt_buffer(message_length)
message_type = ustruct.unpack(
">H", ctx.buffer[INIT_DATA_OFFSET + SESSION_ID_LENGTH :]
)[0]
ctx.connection_context.incoming_message.publish(
MessageWithType(
message_type,
ctx.buffer[
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
# 1. Check that message is expected with respect to the current state
# 2. Handle the message
pass
def _should_have_ctrl_byte_encrypted_transport(ctx: ChannelContext) -> bool:
if ctx.get_channel_state() in [
ChannelState.UNALLOCATED,
ChannelState.TH1,
ChannelState.TH2,
]:
return False
return True
async def _handle_channel_message(
ctx: ChannelContext, message_length: int, message_type: int
) -> None:
buf = ctx.buffer[
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
if not ThpCreateNewSession.is_type_of(message):
raise ThpError(
"The received message cannot be handled by channel itself. It must be sent to allocated session."
)
# TODO handle other messages than CreateNewSession
from trezor.wire.thp.handler_provider import get_handler_for_channel_message
handler = get_handler_for_channel_message(message)
task = handler(ctx, message)
response_message = await task
# TODO handle
await ctx.write(response_message)
if __debug__:
log.debug(__name__, "_handle_channel_message - end")

@ -1,20 +1,25 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage import cache_thp
from storage.cache_thp import SessionThpCache from storage.cache_thp import SessionThpCache
from trezor import log, loop, protobuf from trezor import log, loop, protobuf
from trezor.wire import message_handler, protocol_common from trezor.wire import message_handler, protocol_common
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure
from ..protocol_common import Context, MessageWithType from ..protocol_common import Context, MessageWithType
from . import SessionState from . import ChannelContext, SessionState
from .channel import Channel
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Container # pyright: ignore[reportShadowedImports] from typing import ( # pyright: ignore[reportShadowedImports]
Any,
Awaitable,
Container,
)
pass pass
_EXIT_LOOP = True
_REPEAT_LOOP = False
class UnexpectedMessageWithType(Exception): class UnexpectedMessageWithType(Exception):
"""A message was received that is not part of the current workflow. """A message was received that is not part of the current workflow.
@ -29,29 +34,22 @@ class UnexpectedMessageWithType(Exception):
class SessionContext(Context): class SessionContext(Context):
def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None: def __init__(
if channel.channel_id != session_cache.channel_id: self, channel_ctx: ChannelContext, session_cache: SessionThpCache
) -> None:
if channel_ctx.channel_id != session_cache.channel_id:
raise Exception( raise Exception(
"The session has different channel id than the provided channel context!" "The session has different channel id than the provided channel context!"
) )
super().__init__(channel.iface, channel.channel_id) super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel = channel self.channel_ctx = channel_ctx
self.session_cache = session_cache self.session_cache = session_cache
self.session_id = int.from_bytes(session_cache.session_id, "big") self.session_id = int.from_bytes(session_cache.session_id, "big")
self.incoming_message = loop.chan() self.incoming_message = loop.chan()
@classmethod
def create_new_session(cls, channel_context: Channel) -> "SessionContext":
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
return cls(channel_context, session_cache)
async def handle(self, is_debug_session: bool = False) -> None: async def handle(self, is_debug_session: bool = False) -> None:
if __debug__: if __debug__:
log.debug(__name__, "handle - start (session_id: %d)", self.session_id) self._handle_debug(is_debug_session)
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
take = self.incoming_message.take() take = self.incoming_message.take()
next_message: MessageWithType | None = None next_message: MessageWithType | None = None
@ -61,51 +59,70 @@ class SessionContext(Context):
# TODO modules = utils.unimport_begin() # TODO modules = utils.unimport_begin()
while True: while True:
try: try:
if next_message is None: if await self._handle_message(take, next_message, is_debug_session):
# If the previous run did not keep an unprocessed message for us, return
# wait for a new one.
try:
message: MessageWithType = await take
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(failure(e))
continue
else:
# Process the message from previous run.
message = next_message
next_message = None
try:
next_message = await message_handler.handle_single_message(
self, message, use_workflow=not is_debug_session
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if (
next_message is None
and message.type not in AVOID_RESTARTING_FOR
):
# Shut down the loop if there is no next message waiting.
return # pylint: disable=lost-exception
except Exception as exc: except Exception as exc:
# Log and try again. The session handler can only exit explicitly via # Log and try again.
# loop.clear() above.
if __debug__: if __debug__:
log.exception(__name__, exc) log.exception(__name__, exc)
def _handle_debug(self, is_debug_session: bool) -> None:
log.debug(__name__, "handle - start (session_id: %d)", self.session_id)
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
async def _handle_message(
self,
take: Awaitable[Any],
next_message: MessageWithType | None,
is_debug_session: bool,
) -> bool:
try:
message = await self._get_message(take, next_message)
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(failure(e))
return _REPEAT_LOOP
try:
next_message = await message_handler.handle_single_message(
self, message, use_workflow=not is_debug_session
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None and message.type not in AVOID_RESTARTING_FOR:
# Shut down the loop if there is no next message waiting.
return _EXIT_LOOP # pylint: disable=lost-exception
return _REPEAT_LOOP # pylint: disable=lost-exception
async def _get_message(
self, take: Awaitable[Any], next_message: MessageWithType | None
) -> MessageWithType:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
message: MessageWithType = await take
else:
# Process the message from previous run.
message = next_message
next_message = None
return message
async def read( async def read(
self, self,
expected_types: Container[int], expected_types: Container[int],
@ -131,7 +148,7 @@ class SessionContext(Context):
return message_handler.wrap_protobuf_load(message.data, expected_type) return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None: async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg, self.session_id) return await self.channel_ctx.write(msg, self.session_id)
# ACCESS TO SESSION DATA # ACCESS TO SESSION DATA
@ -141,22 +158,3 @@ class SessionContext(Context):
def set_session_state(self, state: SessionState) -> None: def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.to_bytes(1, "big")) self.session_cache.state = bytearray(state.to_bytes(1, "big"))
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
if __debug__:
log.debug(__name__, "load_cached_sessions")
sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
if __debug__:
log.debug(
__name__,
"load_cached_sessions - loaded a total of %d sessions from cache",
len(cached_sessions),
)
for session in cached_sessions:
if session.channel_id == channel.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel, session)
loop.schedule(sessions[sid].handle())
return sessions

@ -0,0 +1,28 @@
from storage import cache_thp
from trezor import log, loop
from trezor.wire.thp import ChannelContext
from trezor.wire.thp.session_context import SessionContext
def create_new_session(channel_ctx: ChannelContext) -> SessionContext:
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
return SessionContext(channel_ctx, session_cache)
def load_cached_sessions(channel_ctx: ChannelContext) -> dict[int, SessionContext]:
if __debug__:
log.debug(__name__, "load_cached_sessions")
sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
if __debug__:
log.debug(
__name__,
"load_cached_sessions - loaded a total of %d sessions from cache",
len(cached_sessions),
)
for session in cached_sessions:
if session.channel_id == channel_ctx.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel_ctx, session)
loop.schedule(sessions[sid].handle())
return sessions

@ -7,6 +7,8 @@ from trezor.wire.thp.thp_messages import InitHeader
INIT_DATA_OFFSET = const(5) INIT_DATA_OFFSET = const(5)
CONT_DATA_OFFSET = const(3) CONT_DATA_OFFSET = const(3)
REPORT_LENGTH = const(64) REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000)
MESSAGE_TYPE_LENGTH = const(2)
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports] from trezorio import WireInterface # pyright: ignore[reportMissingImports]

@ -6,12 +6,12 @@ from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io, log, loop, utils from trezor import io, log, loop, utils
from .protocol_common import MessageWithId from .protocol_common import MessageWithId
from .thp import ChannelState, checksum, thp_messages from .thp import ChannelState, channel_manager, checksum, session_manager, thp_messages
from .thp.channel import MAX_PAYLOAD_LEN, REPORT_LENGTH, Channel, load_cached_channels from .thp.channel import Channel
from .thp.checksum import CHECKSUM_LENGTH from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, InitHeader from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, InitHeader
from .thp.thp_session import ThpError from .thp.thp_session import ThpError
from .thp.writer import write_payload_to_wire from .thp.writer import MAX_PAYLOAD_LEN, REPORT_LENGTH, write_payload_to_wire
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports] from trezorio import WireInterface # pyright: ignore[reportMissingImports]
@ -33,7 +33,9 @@ def set_buffer(buffer):
async def thp_main_loop(iface: WireInterface, is_debug_session=False): async def thp_main_loop(iface: WireInterface, is_debug_session=False):
global CHANNELS global CHANNELS
global _BUFFER global _BUFFER
CHANNELS = load_cached_channels(_BUFFER) CHANNELS = channel_manager.load_cached_channels(_BUFFER)
for ch in CHANNELS.values():
ch.sessions = session_manager.load_cached_sessions(ch)
read = loop.wait(iface.iface_num() | io.POLL_READ) read = loop.wait(iface.iface_num() | io.POLL_READ)
@ -55,18 +57,9 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
continue continue
if cid in CHANNELS: if cid in CHANNELS:
channel = CHANNELS[cid] await _handle_allocated(iface, cid, packet)
if channel is None: else:
# TODO send error message to wire await _handle_unallocated(iface, cid)
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
# TODO send error message to wire
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
await channel.receive_packet(packet)
continue
await _handle_unallocated(iface, cid)
except ThpError as e: except ThpError as e:
if __debug__: if __debug__:
@ -76,7 +69,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
async def _handle_broadcast( async def _handle_broadcast(
iface: WireInterface, ctrl_byte, packet iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
) -> MessageWithId | None: ) -> MessageWithId | None:
global _BUFFER global _BUFFER
if ctrl_byte != CHANNEL_ALLOCATION_REQ: if ctrl_byte != CHANNEL_ALLOCATION_REQ:
@ -91,7 +84,7 @@ async def _handle_broadcast(
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid") raise ThpError("Checksum is not valid")
new_channel: Channel = Channel.create_new_channel(iface, _BUFFER) new_channel: Channel = channel_manager.create_new_channel(iface, _BUFFER)
cid = int.from_bytes(new_channel.channel_id, "big") cid = int.from_bytes(new_channel.channel_id, "big")
CHANNELS[cid] = new_channel CHANNELS[cid] = new_channel
@ -108,6 +101,21 @@ async def _handle_broadcast(
await write_payload_to_wire(iface, response_header, response_data + chksum) await write_payload_to_wire(iface, response_header, response_data + chksum)
async def _handle_allocated(
iface: WireInterface, cid: int, packet: utils.BufferType
) -> None:
channel = CHANNELS[cid]
if channel is None:
# TODO send error message to wire
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
# TODO send error message to wire
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
await channel.receive_packet(packet)
async def _handle_unallocated(iface, cid) -> MessageWithId | None: async def _handle_unallocated(iface, cid) -> MessageWithId | None:
data = thp_messages.get_error_unallocated_channel() data = thp_messages.get_error_unallocated_channel()
header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)

Loading…
Cancel
Save