Channel refactor

M1nd3r/thp5
M1nd3r 3 weeks ago
parent 531176570d
commit ea6cb32e59

@ -213,16 +213,30 @@ trezor.wire.thp.ack_handler
import trezor.wire.thp.ack_handler
trezor.wire.thp.channel
import trezor.wire.thp.channel
trezor.wire.thp.channel_manager
import trezor.wire.thp.channel_manager
trezor.wire.thp.checksum
import trezor.wire.thp.checksum
trezor.wire.thp.control_byte
import trezor.wire.thp.control_byte
trezor.wire.thp.crypto
import trezor.wire.thp.crypto
trezor.wire.thp.handler_provider
import trezor.wire.thp.handler_provider
trezor.wire.thp.interface_manager
import trezor.wire.thp.interface_manager
trezor.wire.thp.memory_manager
import trezor.wire.thp.memory_manager
trezor.wire.thp.pairing_context
import trezor.wire.thp.pairing_context
trezor.wire.thp.received_message_handler
import trezor.wire.thp.received_message_handler
trezor.wire.thp.retransmission
import trezor.wire.thp.retransmission
trezor.wire.thp.session_context
import trezor.wire.thp.session_context
trezor.wire.thp.session_manager
import trezor.wire.thp.session_manager
trezor.wire.thp.thp_messages
import trezor.wire.thp.thp_messages
trezor.wire.thp.thp_session

@ -1,18 +1,20 @@
from trezor import log, loop
from trezor.messages import ThpCreateNewSession, ThpNewSession
from trezor.wire.thp import SessionState, channel
from trezor.wire.thp import ChannelContext, SessionState
async def create_new_session(
channel: channel.Channel, message: ThpCreateNewSession
channel: ChannelContext, message: ThpCreateNewSession
) -> ThpNewSession:
from trezor.wire.thp.session_context import SessionContext
# from apps.common.seed import get_seed TODO
from trezor.wire.thp.session_manager import create_new_session
session = SessionContext.create_new_session(channel)
session = create_new_session(channel)
session.set_session_state(SessionState.ALLOCATED)
channel.sessions[session.session_id] = session
loop.schedule(session.handle())
new_session_id: int = session.session_id
# await get_seed() TODO
if __debug__:
log.debug(

@ -37,12 +37,12 @@ async def handle_pairing_request(
_check_state(ctx, ChannelState.TP1)
if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry):
ctx.channel.set_channel_state(ChannelState.TP2)
ctx.channel_ctx.set_channel_state(ChannelState.TP2)
response = await ctx.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge)
return await _handle_code_entry_challenge(ctx, response)
ctx.channel.set_channel_state(ChannelState.TP3)
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
response = await ctx.call_any(
ThpPairingPreparationsFinished(),
MessageType.ThpQrCodeTag,
@ -63,7 +63,7 @@ async def _handle_code_entry_challenge(
assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(ctx, ChannelState.TP2)
ctx.channel.set_channel_state(ChannelState.TP3)
ctx.channel_ctx.set_channel_state(ChannelState.TP3)
response = await ctx.call_any(
ThpPairingPreparationsFinished(),
MessageType.ThpCodeEntryCpaceHost,
@ -88,7 +88,7 @@ async def _handle_code_entry_cpace(
_check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry)
ctx.channel.set_channel_state(ChannelState.TP4)
ctx.channel_ctx.set_channel_state(ChannelState.TP4)
response = await ctx.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag)
return await _handle_code_entry_tag(ctx, response)
@ -149,7 +149,7 @@ async def _handle_end_request(
assert ThpEndRequest.is_type_of(message)
_check_state(ctx, ChannelState.TC1)
ctx.channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
ctx.channel_ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()
@ -161,7 +161,7 @@ async def _handle_tag_message(
) -> ThpEndResponse:
_check_state(ctx, expected_state)
_check_method_is_allowed(ctx, used_method)
ctx.channel.set_channel_state(ChannelState.TC1)
ctx.channel_ctx.set_channel_state(ChannelState.TC1)
response = await ctx.call_any(
msg,
MessageType.ThpCredentialRequest,
@ -171,7 +171,7 @@ async def _handle_tag_message(
def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None:
if expected_state is not ctx.channel.get_channel_state():
if expected_state is not ctx.channel_ctx.get_channel_state():
raise UnexpectedMessage("Unexpected message")
@ -181,7 +181,7 @@ def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> N
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel.selected_pairing_methods
return method in ctx.channel_ctx.selected_pairing_methods
async def _handle_credential_request_or_end_request(

@ -2,6 +2,13 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
if TYPE_CHECKING:
from enum import IntEnum
from trezorio import WireInterface
from storage.cache_thp import ChannelCache
from trezor import loop, protobuf, utils
from trezor.enums import FailureType
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.session_context import SessionContext
else:
IntEnum = object
@ -27,3 +34,49 @@ class WireInterfaceType(IntEnum):
MOCK = 0
USB = 1
BLE = 2
class ChannelContext:
def __init__(self, iface: WireInterface, channel_cache: ChannelCache):
self.buffer: utils.BufferType
self.iface: WireInterface = iface
self.channel_id: bytes = channel_cache.channel_id
self.channel_cache: ChannelCache = channel_cache
self.selected_pairing_methods = []
self.sessions: dict[int, SessionContext] = {}
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.connection_context: PairingContext | None = None
def get_channel_state(self) -> int: ...
def set_channel_state(self, state: ChannelState) -> None: ...
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: ...
async def write_error(self, err_type: FailureType, message: str) -> None: ...
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: ...
def decrypt_buffer(self, message_length: int) -> None: ...
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,
ChannelState.TP2,
ChannelState.TP3,
ChannelState.TP4,
ChannelState.TC1,
):
return True
return False
if __debug__:
def state_to_str(state: int) -> str:
name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
}.get(state)
if name is not None:
return name
return "UNKNOWN_STATE"

@ -1,101 +1,60 @@
import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
import usb
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from storage.cache_thp import TAG_LENGTH, ChannelCache
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType, MessageType
from trezor.messages import (
Failure,
ThpCreateNewSession,
ThpHandshakeCompletionReqNoisePayload,
from trezor.enums import FailureType
from trezor.wire.thp import interface_manager, received_message_handler
from . import (
ChannelContext,
ChannelState,
checksum,
control_byte,
crypto,
memory_manager,
)
from trezor.wire import message_handler
from trezor.wire.thp import ack_handler, thp_messages
from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum, crypto
from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH
from .crypto import PUBKEY_LENGTH
from .thp_messages import (
ACK_MESSAGE,
CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT,
ERROR,
HANDSHAKE_COMP_REQ,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_REQ,
HANDSHAKE_INIT_RES,
InitHeader,
)
from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader
from .thp_session import ThpError
from .writer import (
CONT_DATA_OFFSET,
INIT_DATA_OFFSET,
REPORT_LENGTH,
MESSAGE_TYPE_LENGTH,
write_payload_to_wire,
)
if TYPE_CHECKING:
from trezorio import WireInterface # pyright:ignore[reportMissingImports]
_WIRE_INTERFACE_USB = b"\x01"
_MOCK_INTERFACE_HID = b"\x00"
if __debug__:
from . import state_to_str
MESSAGE_TYPE_LENGTH = const(2)
MAX_PAYLOAD_LEN = const(60000)
if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
class Channel(Context):
class Channel(ChannelContext):
def __init__(self, channel_cache: ChannelCache) -> None:
if __debug__:
log.debug(__name__, "channel initialization")
iface = _decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache.channel_id)
iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache)
self.channel_cache = channel_cache
self.buffer: utils.BufferType
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
self.selected_pairing_methods = []
from trezor.wire.thp.session_context import load_cached_sessions
self.connection_context = None
self.sessions = load_cached_sessions(self)
@classmethod
def create_new_channel(
cls, iface: WireInterface, buffer: utils.BufferType
) -> "Channel":
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
r = cls(channel_cache)
r.set_buffer(buffer)
r.set_channel_state(ChannelState.TH1)
return r
# ACCESS TO CHANNEL_DATA
def get_channel_state(self) -> int:
state = int.from_bytes(self.channel_cache.state, "big")
if __debug__:
log.debug(__name__, "get_channel_state: %s", _state_to_str(state))
log.debug(__name__, "get_channel_state: %s", state_to_str(state))
return state
def get_channel_id_int(self) -> int:
return int.from_bytes(self.channel_id, "big")
def set_channel_state(self, state: ChannelState) -> None:
if __debug__:
log.debug(__name__, "set_channel_state: %s", _state_to_str(state))
self.channel_cache.state = bytearray(state.to_bytes(1, "big"))
if __debug__:
log.debug(__name__, "set_channel_state: %s", state_to_str(state))
def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer
@ -115,7 +74,7 @@ class Channel(Context):
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
self._finish_message()
await self._handle_completed_message()
await received_message_handler.handle_received_message(self, self.buffer)
elif self.expected_payload_length + INIT_DATA_OFFSET > self.bytes_read:
self.is_cont_packet_expected = True
else:
@ -125,7 +84,7 @@ class Channel(Context):
async def _handle_received_packet(self, packet: utils.BufferType) -> None:
ctrl_byte = packet[0]
if _is_ctrl_byte_continuation(ctrl_byte):
if control_byte.is_continuation(ctrl_byte):
await self._handle_cont_packet(packet)
else:
await self._handle_init_packet(packet)
@ -138,42 +97,21 @@ class Channel(Context):
packet_payload = packet[5:]
# If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed!
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
if control_byte.is_encrypted_transport(ctrl_byte):
packet_payload = self._decrypt_single_packet_payload(packet_payload)
self._select_buffer(packet_payload, payload_length)
self.buffer = memory_manager.select_buffer(
self.get_channel_state(),
self.buffer,
packet_payload,
payload_length,
)
await self._buffer_packet_data(self.buffer, packet, 0)
if __debug__:
log.debug(__name__, "handle_init_packet - payload len: %d", payload_length)
log.debug(__name__, "handle_init_packet - buffer len: %d", len(self.buffer))
def _select_buffer(
self, packet_payload: utils.BufferType, payload_length: int
) -> None:
state = self.get_channel_state()
if state is ChannelState.ENCRYPTED_TRANSPORT:
session_id = packet_payload[0]
if session_id == 0:
pass
# TODO use small buffer
else:
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return
else:
pass
# TODO use small buffer
try:
# TODO for now, we create a new big buffer every time. It should be changed
self.buffer: utils.BufferType = _get_buffer_for_message(
payload_length, self.buffer
)
except Exception as e:
if __debug__:
log.exception(__name__, e)
async def _handle_cont_packet(self, packet: utils.BufferType) -> None:
if __debug__:
log.debug(__name__, "handle_cont_packet")
@ -181,299 +119,12 @@ class Channel(Context):
raise ThpError("Continuation packet is not expected, ignoring")
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
async def _handle_completed_message(self) -> None:
if __debug__:
log.debug(__name__, "handle_completed_message")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
message_length = payload_length + INIT_DATA_OFFSET
self._check_checksum(message_length)
# Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4
if __debug__:
log.debug(
__name__,
"handle_completed_message - sync bit of message: %d",
sync_bit,
)
# 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte):
await self._handle_ack(sync_bit)
return
if (
self._should_have_ctrl_byte_encrypted_transport()
and not _is_ctrl_byte_encrypted_transport(ctrl_byte)
):
self._todo_clear_buffer()
raise ThpError("Message is not encrypted. Ignoring")
# 2: Handle message with unexpected synchronization bit
if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache):
if __debug__:
log.debug(
__name__, "Received message with an unexpected synchronization bit"
)
await self._send_ack(sync_bit)
raise ThpError("Received message with an unexpected synchronization bit")
# 3: Send ACK in response
await self._send_ack(sync_bit)
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
await self._handle_message_to_app_or_channel(
payload_length, message_length, ctrl_byte, sync_bit
)
if __debug__:
log.debug(__name__, "handle_completed_message - end")
async def _handle_ack(self, sync_bit: int):
if not ack_handler.is_ack_valid(self.channel_cache, sync_bit):
return
# ACK is expected and it has correct sync bit
if __debug__:
log.debug(__name__, "Received ACK message with correct sync bit")
if self.waiting_for_ack_timeout is not None:
self.waiting_for_ack_timeout.close()
if __debug__:
log.debug(__name__, 'Closed "waiting for ack" task')
THP.sync_set_can_send_message(self.channel_cache, True)
if self.write_task_spawn is not None:
if __debug__:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await self.write_task_spawn
# Note that no the write_task_spawn could result in loop.clear(),
# which will result in terminations of this function - any code after
# this await might not be executed
def _check_checksum(self, message_length: int):
if __debug__:
log.debug(__name__, "check_checksum")
if not checksum.is_valid(
checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length],
data=self.buffer[: message_length - CHECKSUM_LENGTH],
):
self._todo_clear_buffer()
if __debug__:
log.debug(__name__, "Invalid checksum, ignoring message.")
raise ThpError("Invalid checksum, ignoring message.")
async def _handle_message_to_app_or_channel(
self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
state = self.get_channel_state()
if __debug__:
log.debug(__name__, "state: %s", _state_to_str(state))
if state is ChannelState.ENCRYPTED_TRANSPORT:
await self._handle_state_ENCRYPTED_TRANSPORT(message_length)
return
if state is ChannelState.TH1:
await self._handle_state_TH1(
payload_length, message_length, ctrl_byte, sync_bit
)
return
if state is ChannelState.TH2:
await self._handle_state_TH2(message_length, ctrl_byte, sync_bit)
return
if is_channel_state_pairing(state):
await self._handle_pairing(message_length)
return
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH1")
if not _is_ctrl_byte_handshake_init_req(ctrl_byte):
raise ThpError("Message received is not a handshake init request!")
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
host_ephemeral_key = bytearray(
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
# send handshake init response message
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response()
)
)
self.set_channel_state(ChannelState.TH2)
return
async def _handle_state_TH2(
self, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH2")
if not _is_ctrl_byte_handshake_comp_req(ctrl_byte):
raise ThpError("Message received is not a handshake completion request!")
host_encrypted_static_pubkey = self.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = self.buffer[
INIT_DATA_OFFSET
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
]
noise_payload = thp_messages.decode_message(
self.buffer[
INIT_DATA_OFFSET
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
0,
"ThpHandshakeCompletionReqNoisePayload",
)
if TYPE_CHECKING:
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
for i in noise_payload.pairing_methods:
self.selected_pairing_methods.append(i)
if __debug__:
log.debug(
__name__,
"host static pubkey: %s, noise payload: %s",
utils.get_bytes_as_str(host_encrypted_static_pubkey),
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
# TODO add credential recognition
paired: bool = False # TODO should be output from credential check
# send hanshake completion response
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
HANDSHAKE_COMP_RES,
thp_messages.get_handshake_completion_response(paired=paired),
)
)
if paired:
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
self.set_channel_state(ChannelState.TP1)
async def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
if __debug__:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
self._decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
await self._handle_channel_message(message_length, message_type)
return
if session_id not in self.sessions:
await self.write_error(
FailureType.ThpUnallocatedSession, "Unallocated session"
)
raise ThpError("Unalloacted session")
session_state = self.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED:
await self.write_error(
FailureType.ThpUnallocatedSession, "Unallocated session"
)
raise ThpError("Unalloacted session")
self.sessions[session_id].incoming_message.publish(
MessageWithType(
message_type,
self.buffer[
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
async def _handle_pairing(self, message_length: int) -> None:
from .pairing_context import PairingContext
if self.connection_context is None:
self.connection_context = PairingContext(self)
loop.schedule(self.connection_context.handle())
self._decrypt_buffer(message_length)
message_type = ustruct.unpack(
">H", self.buffer[INIT_DATA_OFFSET + SESSION_ID_LENGTH :]
)[0]
self.connection_context.incoming_message.publish(
MessageWithType(
message_type,
self.buffer[
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
# 1. Check that message is expected with respect to the current state
# 2. Handle the message
pass
def _should_have_ctrl_byte_encrypted_transport(self) -> bool:
if self.get_channel_state() in [
ChannelState.UNALLOCATED,
ChannelState.TH1,
ChannelState.TH2,
]:
return False
return True
async def _handle_channel_message(
self, message_length: int, message_type: int
) -> None:
buf = self.buffer[
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
if not ThpCreateNewSession.is_type_of(message):
raise ThpError(
"This message cannot be handled by channel itself. It must be send to allocated session."
)
# TODO handle other messages than CreateNewSession
from trezor.wire.thp.handler_provider import get_handler_for_channel_message
handler = get_handler_for_channel_message(message)
task = handler(self, message)
response_message = await task
# TODO handle
await self.write(response_message)
if __debug__:
log.debug(__name__, "_handle_channel_message - end")
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
payload_buffer = bytearray(payload)
crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
return payload_buffer
def _decrypt_buffer(self, message_length: int) -> None:
def decrypt_buffer(self, message_length: int) -> None:
if not isinstance(self.buffer, bytearray):
self.buffer = bytearray(self.buffer)
crypto.decrypt(
@ -511,38 +162,22 @@ class Channel(Context):
self.expected_payload_length = 0
self.is_cont_packet_expected = False
async def _send_ack(self, ack_bit: int) -> None:
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = InitHeader(ctrl_byte, self.get_channel_id_int(), CHECKSUM_LENGTH)
chksum = checksum.compute(header.to_bytes())
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
self.get_channel_id_int(),
ack_bit,
)
await write_payload_to_wire(self.iface, header, chksum)
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
if sync_bit == 0:
return ctrl_byte & 0xEF
if sync_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit")
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
if __debug__:
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
noise_payload_len = self._encode_into_buffer(msg, session_id)
noise_payload_len = memory_manager.encode_into_buffer(
memoryview(self.buffer), msg, session_id
)
await self.write_and_encrypt(self.buffer[:noise_payload_len])
async def write_error(self, err_type: FailureType, message: str) -> None:
if __debug__:
log.debug(__name__, "write_error")
msg_size = self._encode_error_into_buffer(err_type, message)
msg_size = memory_manager.encode_error_into_buffer(
memoryview(self.buffer), err_type, message
)
data_length = MESSAGE_TYPE_LENGTH + msg_size
header: InitHeader = InitHeader(
ERROR, self.get_channel_id_int(), data_length + CHECKSUM_LENGTH
@ -574,6 +209,12 @@ class Channel(Context):
)
)
async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(ctrl_byte, payload)
)
def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false
THP.sync_set_can_send_message(self.channel_cache, False)
@ -585,7 +226,7 @@ class Channel(Context):
log.debug(__name__, "write_encrypted_payload_loop")
payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = THP.sync_get_send_bit(self.channel_cache)
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit)
ctrl_byte = control_byte.add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit)
header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
chksum = checksum.compute(header.to_bytes() + payload)
payload = payload + chksum
@ -627,160 +268,3 @@ class Channel(Context):
async def _wait_for_ack(self) -> None:
await loop.sleep(1000)
def _encode_into_buffer(self, msg: protobuf.MessageType, session_id: int) -> int:
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(self.buffer):
# message is too big, we need to allocate a new buffer
self.buffer = bytearray(required_min_size)
buffer = self.buffer
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(
memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
)
_encode_message_into_buffer(
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
return payload_size
def _encode_error_into_buffer(self, err_code: FailureType, message: str) -> int:
error_message: protobuf.MessageType = Failure(code=err_code, message=message)
_encode_message_type_into_buffer(memoryview(self.buffer), MessageType.Failure)
_encode_message_into_buffer(
memoryview(self.buffer), error_message, MESSAGE_TYPE_LENGTH
)
return protobuf.encoded_length(error_message)
def _todo_clear_buffer(self):
# TODO Buffer clearing not implemented
pass
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values():
c.set_buffer(buffer)
return channels
def _decode_iface(cached_iface: bytes) -> WireInterface:
if cached_iface == _WIRE_INTERFACE_USB:
iface = usb.iface_wire
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
if __debug__ and cached_iface == _MOCK_INTERFACE_HID:
raise NotImplementedError("Should return MockHID WireInterface")
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def _encode_iface(iface: WireInterface) -> bytes:
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
if __debug__:
return _MOCK_INTERFACE_HID
raise Exception("Unknown WireInterface")
def _get_buffer_for_message(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
length = payload_length + INIT_DATA_OFFSET
if __debug__:
log.debug(
__name__,
"get_buffer_for_message - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
# allocate a new buffer to fit the message
try:
payload: utils.BufferType = bytearray(length)
except MemoryError:
payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:length]
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & 0x80 == CONTINUATION_PACKET
def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def _is_ctrl_byte_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ
def _is_ctrl_byte_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE
def is_channel_state_pairing(state: int) -> bool:
if state in (
ChannelState.TP1,
ChannelState.TP2,
ChannelState.TP3,
ChannelState.TP4,
ChannelState.TC1,
):
return True
return False
def _encode_session_into_buffer(
buffer: memoryview, session_id: int, buffer_offset: int = 0
) -> None:
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
def _encode_message_type_into_buffer(
buffer: memoryview, message_type: int, offset: int = 0
) -> None:
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, offset, msg_type_bytes, 0)
def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _state_to_str(state: int) -> str:
name = {
v: k for k, v in ChannelState.__dict__.items() if not k.startswith("__")
}.get(state)
if name is not None:
return name
return "UNKNOWN_STATE"

@ -0,0 +1,30 @@
from typing import TYPE_CHECKING
from storage import cache_thp
from trezor import utils
from . import ChannelState, interface_manager
from .channel import Channel
if TYPE_CHECKING:
from trezorio import WireInterface # pyright:ignore[reportMissingImports]
def create_new_channel(iface: WireInterface, buffer: utils.BufferType) -> "Channel":
channel_cache = cache_thp.get_new_unauthenticated_channel(
interface_manager.encode_iface(iface)
)
r = Channel(channel_cache)
r.set_buffer(buffer)
r.set_channel_state(ChannelState.TH1)
return r
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values():
c.set_buffer(buffer)
return channels

@ -0,0 +1,36 @@
from trezor.wire.thp.thp_messages import (
ACK_MESSAGE,
CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT,
HANDSHAKE_COMP_REQ,
HANDSHAKE_INIT_REQ,
)
from trezor.wire.thp.thp_session import ThpError
def add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
if sync_bit == 0:
return ctrl_byte & 0xEF
if sync_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit")
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & 0x80 == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ

@ -0,0 +1,32 @@
from typing import TYPE_CHECKING
import usb
_MOCK_INTERFACE_HID = b"\x00"
_WIRE_INTERFACE_USB = b"\x01"
if TYPE_CHECKING:
from trezorio import WireInterface # pyright:ignore[reportMissingImports]
def decode_iface(cached_iface: bytes) -> WireInterface:
"""Decode the cached wire interface."""
if cached_iface == _WIRE_INTERFACE_USB:
iface = usb.iface_wire
if iface is None:
raise RuntimeError("There is no valid USB WireInterface")
return iface
if __debug__ and cached_iface == _MOCK_INTERFACE_HID:
raise NotImplementedError("Should return MockHID WireInterface")
# TODO implement bluetooth interface
raise Exception("Unknown WireInterface")
def encode_iface(iface: WireInterface) -> bytes:
"""Encode wire interface into bytes."""
if iface is usb.iface_wire:
return _WIRE_INTERFACE_USB
# TODO implement bluetooth interface
if __debug__:
return _MOCK_INTERFACE_HID
raise Exception("Unknown WireInterface")

@ -0,0 +1,128 @@
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, protobuf, utils
from trezor.enums import FailureType, MessageType
from trezor.messages import Failure
from . import ChannelState
from .checksum import CHECKSUM_LENGTH
from .thp_session import ThpError
from .writer import (
INIT_DATA_OFFSET,
MAX_PAYLOAD_LEN,
MESSAGE_TYPE_LENGTH,
REPORT_LENGTH,
)
def select_buffer(
channel_state: int,
channel_buffer: utils.BufferType,
packet_payload: utils.BufferType,
payload_length: int,
) -> utils.BufferType:
if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
session_id = packet_payload[0]
if session_id == 0:
pass
# TODO use small buffer
else:
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return
else:
pass
# TODO use small buffer
try:
# TODO for now, we create a new big buffer every time. It should be changed
buffer: utils.BufferType = _get_buffer_for_message(
payload_length, channel_buffer
)
return buffer
except Exception as e:
if __debug__:
log.exception(__name__, e)
raise Exception("Failed to create a buffer for channel") # TODO handle better
def encode_into_buffer(
buffer: memoryview, msg: protobuf.MessageType, session_id: int
) -> int:
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(buffer):
# message is too big, we need to allocate a new buffer
buffer = memoryview(bytearray(required_min_size))
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(
memoryview(buffer), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH
)
_encode_message_into_buffer(
memoryview(buffer), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
)
return payload_size
def encode_error_into_buffer(
buffer: memoryview, err_code: FailureType, message: str
) -> int:
error_message: protobuf.MessageType = Failure(code=err_code, message=message)
_encode_message_type_into_buffer(buffer, MessageType.Failure)
_encode_message_into_buffer(buffer, error_message, MESSAGE_TYPE_LENGTH)
return protobuf.encoded_length(error_message)
def _encode_session_into_buffer(
buffer: memoryview, session_id: int, buffer_offset: int = 0
) -> None:
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
utils.memcpy(buffer, buffer_offset, session_id_bytes, 0)
def _encode_message_type_into_buffer(
buffer: memoryview, message_type: int, offset: int = 0
) -> None:
msg_type_bytes = int.to_bytes(message_type, MESSAGE_TYPE_LENGTH, "big")
utils.memcpy(buffer, offset, msg_type_bytes, 0)
def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _get_buffer_for_message(
payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN
) -> utils.BufferType:
length = payload_length + INIT_DATA_OFFSET
if __debug__:
log.debug(
__name__,
"get_buffer_for_message - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
# allocate a new buffer to fit the message
try:
payload: utils.BufferType = bytearray(length)
except MemoryError:
payload = bytearray(REPORT_LENGTH)
raise ThpError("Message too large")
return payload
# reuse a part of the supplied buffer
return memoryview(existing_buffer)[:length]

@ -5,9 +5,9 @@ from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageWithId
from trezor.wire.errors import ActionCancelled
from trezor.wire.protocol_common import Context, MessageWithType
from trezor.wire.thp.session_context import UnexpectedMessageWithType
from .channel import Channel
from . import ChannelContext
from .session_context import UnexpectedMessageWithType
if TYPE_CHECKING:
from typing import Container # pyright:ignore[reportShadowedImports]
@ -16,9 +16,9 @@ if TYPE_CHECKING:
class PairingContext(Context):
def __init__(self, channel: Channel) -> None:
super().__init__(channel.iface, channel.channel_id)
self.channel = channel
def __init__(self, channel_ctx: ChannelContext) -> None:
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel_ctx = channel_ctx
self.incoming_message = loop.chan()
async def handle(self, is_debug_session: bool = False) -> None:
@ -104,7 +104,7 @@ class PairingContext(Context):
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg)
return await self.channel_ctx.write(msg)
async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
@ -125,7 +125,9 @@ class PairingContext(Context):
async def handle_pairing_request_message(
ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool
pairing_ctx: PairingContext,
msg: protocol_common.MessageWithType,
use_workflow: bool,
) -> protocol_common.MessageWithType | None:
res_msg: protobuf.MessageType | None = None
@ -147,7 +149,7 @@ async def handle_pairing_request_message(
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handle_pairing_request(ctx, req_msg)
task = handle_pairing_request(pairing_ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
@ -156,7 +158,7 @@ async def handle_pairing_request_message(
if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(context.with_context(ctx, task))
res_msg = await workflow.spawn(context.with_context(pairing_ctx, task))
pass # TODO
else:
# For debug messages, ignore workflow processing and just await
@ -193,5 +195,5 @@ async def handle_pairing_request_message(
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
await pairing_ctx.write(res_msg)
return None

@ -0,0 +1,349 @@
import ustruct # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType
from trezor.messages import ThpCreateNewSession
from trezor.wire import message_handler
from trezor.wire.protocol_common import MessageWithType
from trezor.wire.thp import ack_handler, thp_messages
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
from trezor.wire.thp.crypto import PUBKEY_LENGTH
from trezor.wire.thp.thp_messages import (
ACK_MESSAGE,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_RES,
InitHeader,
)
from . import (
ChannelContext,
ChannelState,
SessionState,
checksum,
control_byte,
is_channel_state_pairing,
)
from . import thp_session as THP
from .thp_session import ThpError
from .writer import INIT_DATA_OFFSET, MESSAGE_TYPE_LENGTH, write_payload_to_wire
if TYPE_CHECKING:
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
if __debug__:
from . import state_to_str
async def handle_received_message(
ctx: ChannelContext, message_buffer: utils.BufferType
) -> None:
"""Handle a message received from the channel."""
if __debug__:
log.debug(__name__, "handle_received_message")
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", message_buffer)
message_length = payload_length + INIT_DATA_OFFSET
_check_checksum(message_length, message_buffer)
# Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4
if __debug__:
log.debug(
__name__,
"handle_completed_message - sync bit of message: %d",
sync_bit,
)
# 1: Handle ACKs
if control_byte.is_ack(ctrl_byte):
await _handle_ack(ctx, sync_bit)
return
if _should_have_ctrl_byte_encrypted_transport(
ctx
) and not control_byte.is_encrypted_transport(ctrl_byte):
raise ThpError("Message is not encrypted. Ignoring")
# 2: Handle message with unexpected synchronization bit
if sync_bit != THP.sync_get_receive_expected_bit(ctx.channel_cache):
if __debug__:
log.debug(
__name__, "Received message with an unexpected synchronization bit"
)
await _send_ack(ctx, sync_bit)
raise ThpError("Received message with an unexpected synchronization bit")
# 3: Send ACK in response
await _send_ack(ctx, sync_bit)
THP.sync_set_receive_expected_bit(ctx.channel_cache, 1 - sync_bit)
await _handle_message_to_app_or_channel(
ctx, payload_length, message_length, ctrl_byte, sync_bit
)
if __debug__:
log.debug(__name__, "handle_received_message - end")
async def _send_ack(ctx: ChannelContext, ack_bit: int) -> None:
ctrl_byte = control_byte.add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = InitHeader(ctrl_byte, ctx.get_channel_id_int(), CHECKSUM_LENGTH)
chksum = checksum.compute(header.to_bytes())
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
ctx.get_channel_id_int(),
ack_bit,
)
await write_payload_to_wire(ctx.iface, header, chksum)
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
if __debug__:
log.debug(__name__, "check_checksum")
if not checksum.is_valid(
checksum=message_buffer[message_length - CHECKSUM_LENGTH : message_length],
data=message_buffer[: message_length - CHECKSUM_LENGTH],
):
if __debug__:
log.debug(__name__, "Invalid checksum, ignoring message.")
raise ThpError("Invalid checksum, ignoring message.")
# TEST THIS
async def _handle_ack(ctx: ChannelContext, sync_bit: int):
if not ack_handler.is_ack_valid(ctx.channel_cache, sync_bit):
return
# ACK is expected and it has correct sync bit
if __debug__:
log.debug(__name__, "Received ACK message with correct sync bit")
if ctx.waiting_for_ack_timeout is not None:
ctx.waiting_for_ack_timeout.close()
if __debug__:
log.debug(__name__, 'Closed "waiting for ack" task')
THP.sync_set_can_send_message(ctx.channel_cache, True)
if ctx.write_task_spawn is not None:
if __debug__:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await ctx.write_task_spawn
# Note that no the write_task_spawn could result in loop.clear(),
# which will result in terminations of this function - any code after
# this await might not be executed
async def _handle_message_to_app_or_channel(
ctx: ChannelContext,
payload_length: int,
message_length: int,
ctrl_byte: int,
sync_bit: int,
) -> None:
state = ctx.get_channel_state()
if __debug__:
log.debug(__name__, "state: %s", state_to_str(state))
if state is ChannelState.ENCRYPTED_TRANSPORT:
await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
return
if state is ChannelState.TH1:
await _handle_state_TH1(
ctx, payload_length, message_length, ctrl_byte, sync_bit
)
return
if state is ChannelState.TH2:
await _handle_state_TH2(ctx, message_length, ctrl_byte, sync_bit)
return
if is_channel_state_pairing(state):
await _handle_pairing(ctx, message_length)
return
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
ctx: ChannelContext,
payload_length: int,
message_length: int,
ctrl_byte: int,
sync_bit: int,
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH1")
if not control_byte.is_handshake_init_req(ctrl_byte):
raise ThpError("Message received is not a handshake init request!")
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
host_ephemeral_key = bytearray(
ctx.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(ctx.channel_cache, host_ephemeral_key)
# send handshake init response message
await ctx.write_handshake_message(
HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response()
)
ctx.set_channel_state(ChannelState.TH2)
return
async def _handle_state_TH2(
ctx: ChannelContext, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH2")
if not control_byte.is_handshake_comp_req(ctrl_byte):
raise ThpError("Message received is not a handshake completion request!")
host_encrypted_static_pubkey = ctx.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = ctx.buffer[
INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
]
noise_payload = thp_messages.decode_message(
ctx.buffer[
INIT_DATA_OFFSET
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
0,
"ThpHandshakeCompletionReqNoisePayload",
)
if TYPE_CHECKING:
assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)
for i in noise_payload.pairing_methods:
ctx.selected_pairing_methods.append(i)
if __debug__:
log.debug(
__name__,
"host static pubkey: %s, noise payload: %s",
utils.get_bytes_as_str(host_encrypted_static_pubkey),
utils.get_bytes_as_str(handshake_completion_request_noise_payload),
)
# TODO add credential recognition
paired: bool = True # TODO should be output from credential check
# send hanshake completion response
await ctx.write_handshake_message(
HANDSHAKE_COMP_RES,
thp_messages.get_handshake_completion_response(paired),
)
if paired:
ctx.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
else:
ctx.set_channel_state(ChannelState.TP1)
async def _handle_state_ENCRYPTED_TRANSPORT(
ctx: ChannelContext, message_length: int
) -> None:
if __debug__:
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
ctx.decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", ctx.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
await _handle_channel_message(ctx, message_length, message_type)
return
if session_id not in ctx.sessions:
await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session")
raise ThpError("Unalloacted session")
session_state = ctx.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED:
await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session")
raise ThpError("Unalloacted session")
ctx.sessions[session_id].incoming_message.publish(
MessageWithType(
message_type,
ctx.buffer[
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
async def _handle_pairing(ctx: ChannelContext, message_length: int) -> None:
from .pairing_context import PairingContext
if ctx.connection_context is None:
ctx.connection_context = PairingContext(ctx)
loop.schedule(ctx.connection_context.handle())
ctx.decrypt_buffer(message_length)
message_type = ustruct.unpack(
">H", ctx.buffer[INIT_DATA_OFFSET + SESSION_ID_LENGTH :]
)[0]
ctx.connection_context.incoming_message.publish(
MessageWithType(
message_type,
ctx.buffer[
INIT_DATA_OFFSET
+ MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length
- CHECKSUM_LENGTH
- TAG_LENGTH
],
)
)
# 1. Check that message is expected with respect to the current state
# 2. Handle the message
pass
def _should_have_ctrl_byte_encrypted_transport(ctx: ChannelContext) -> bool:
if ctx.get_channel_state() in [
ChannelState.UNALLOCATED,
ChannelState.TH1,
ChannelState.TH2,
]:
return False
return True
async def _handle_channel_message(
ctx: ChannelContext, message_length: int, message_type: int
) -> None:
buf = ctx.buffer[
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
if not ThpCreateNewSession.is_type_of(message):
raise ThpError(
"The received message cannot be handled by channel itself. It must be sent to allocated session."
)
# TODO handle other messages than CreateNewSession
from trezor.wire.thp.handler_provider import get_handler_for_channel_message
handler = get_handler_for_channel_message(message)
task = handler(ctx, message)
response_message = await task
# TODO handle
await ctx.write(response_message)
if __debug__:
log.debug(__name__, "_handle_channel_message - end")

@ -1,20 +1,25 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage import cache_thp
from storage.cache_thp import SessionThpCache
from trezor import log, loop, protobuf
from trezor.wire import message_handler, protocol_common
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure
from ..protocol_common import Context, MessageWithType
from . import SessionState
from .channel import Channel
from . import ChannelContext, SessionState
if TYPE_CHECKING:
from typing import Container # pyright: ignore[reportShadowedImports]
from typing import ( # pyright: ignore[reportShadowedImports]
Any,
Awaitable,
Container,
)
pass
_EXIT_LOOP = True
_REPEAT_LOOP = False
class UnexpectedMessageWithType(Exception):
"""A message was received that is not part of the current workflow.
@ -29,29 +34,22 @@ class UnexpectedMessageWithType(Exception):
class SessionContext(Context):
def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None:
if channel.channel_id != session_cache.channel_id:
def __init__(
self, channel_ctx: ChannelContext, session_cache: SessionThpCache
) -> None:
if channel_ctx.channel_id != session_cache.channel_id:
raise Exception(
"The session has different channel id than the provided channel context!"
)
super().__init__(channel.iface, channel.channel_id)
self.channel = channel
super().__init__(channel_ctx.iface, channel_ctx.channel_id)
self.channel_ctx = channel_ctx
self.session_cache = session_cache
self.session_id = int.from_bytes(session_cache.session_id, "big")
self.incoming_message = loop.chan()
@classmethod
def create_new_session(cls, channel_context: Channel) -> "SessionContext":
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
return cls(channel_context, session_cache)
async def handle(self, is_debug_session: bool = False) -> None:
if __debug__:
log.debug(__name__, "handle - start (session_id: %d)", self.session_id)
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
self._handle_debug(is_debug_session)
take = self.incoming_message.take()
next_message: MessageWithType | None = None
@ -61,51 +59,70 @@ class SessionContext(Context):
# TODO modules = utils.unimport_begin()
while True:
try:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
try:
message: MessageWithType = await take
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(failure(e))
continue
else:
# Process the message from previous run.
message = next_message
next_message = None
try:
next_message = await message_handler.handle_single_message(
self, message, use_workflow=not is_debug_session
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if (
next_message is None
and message.type not in AVOID_RESTARTING_FOR
):
# Shut down the loop if there is no next message waiting.
return # pylint: disable=lost-exception
if await self._handle_message(take, next_message, is_debug_session):
return
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above.
# Log and try again.
if __debug__:
log.exception(__name__, exc)
def _handle_debug(self, is_debug_session: bool) -> None:
log.debug(__name__, "handle - start (session_id: %d)", self.session_id)
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
async def _handle_message(
self,
take: Awaitable[Any],
next_message: MessageWithType | None,
is_debug_session: bool,
) -> bool:
try:
message = await self._get_message(take, next_message)
except protocol_common.WireError as e:
if __debug__:
log.exception(__name__, e)
await self.write(failure(e))
return _REPEAT_LOOP
try:
next_message = await message_handler.handle_single_message(
self, message, use_workflow=not is_debug_session
)
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
if not __debug__ or not is_debug_session:
# Unload modules imported by the workflow. Should not raise.
# This is not done for the debug session because the snapshot taken
# in a debug session would clear modules which are in use by the
# workflow running on wire.
# TODO utils.unimport_end(modules)
if next_message is None and message.type not in AVOID_RESTARTING_FOR:
# Shut down the loop if there is no next message waiting.
return _EXIT_LOOP # pylint: disable=lost-exception
return _REPEAT_LOOP # pylint: disable=lost-exception
async def _get_message(
self, take: Awaitable[Any], next_message: MessageWithType | None
) -> MessageWithType:
if next_message is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one.
message: MessageWithType = await take
else:
# Process the message from previous run.
message = next_message
next_message = None
return message
async def read(
self,
expected_types: Container[int],
@ -131,7 +148,7 @@ class SessionContext(Context):
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg, self.session_id)
return await self.channel_ctx.write(msg, self.session_id)
# ACCESS TO SESSION DATA
@ -141,22 +158,3 @@ class SessionContext(Context):
def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
if __debug__:
log.debug(__name__, "load_cached_sessions")
sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
if __debug__:
log.debug(
__name__,
"load_cached_sessions - loaded a total of %d sessions from cache",
len(cached_sessions),
)
for session in cached_sessions:
if session.channel_id == channel.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel, session)
loop.schedule(sessions[sid].handle())
return sessions

@ -0,0 +1,28 @@
from storage import cache_thp
from trezor import log, loop
from trezor.wire.thp import ChannelContext
from trezor.wire.thp.session_context import SessionContext
def create_new_session(channel_ctx: ChannelContext) -> SessionContext:
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
return SessionContext(channel_ctx, session_cache)
def load_cached_sessions(channel_ctx: ChannelContext) -> dict[int, SessionContext]:
if __debug__:
log.debug(__name__, "load_cached_sessions")
sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
if __debug__:
log.debug(
__name__,
"load_cached_sessions - loaded a total of %d sessions from cache",
len(cached_sessions),
)
for session in cached_sessions:
if session.channel_id == channel_ctx.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel_ctx, session)
loop.schedule(sessions[sid].handle())
return sessions

@ -7,6 +7,8 @@ from trezor.wire.thp.thp_messages import InitHeader
INIT_DATA_OFFSET = const(5)
CONT_DATA_OFFSET = const(3)
REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000)
MESSAGE_TYPE_LENGTH = const(2)
if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports]

@ -6,12 +6,12 @@ from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io, log, loop, utils
from .protocol_common import MessageWithId
from .thp import ChannelState, checksum, thp_messages
from .thp.channel import MAX_PAYLOAD_LEN, REPORT_LENGTH, Channel, load_cached_channels
from .thp import ChannelState, channel_manager, checksum, session_manager, thp_messages
from .thp.channel import Channel
from .thp.checksum import CHECKSUM_LENGTH
from .thp.thp_messages import CHANNEL_ALLOCATION_REQ, CODEC_V1, InitHeader
from .thp.thp_session import ThpError
from .thp.writer import write_payload_to_wire
from .thp.writer import MAX_PAYLOAD_LEN, REPORT_LENGTH, write_payload_to_wire
if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
@ -33,7 +33,9 @@ def set_buffer(buffer):
async def thp_main_loop(iface: WireInterface, is_debug_session=False):
global CHANNELS
global _BUFFER
CHANNELS = load_cached_channels(_BUFFER)
CHANNELS = channel_manager.load_cached_channels(_BUFFER)
for ch in CHANNELS.values():
ch.sessions = session_manager.load_cached_sessions(ch)
read = loop.wait(iface.iface_num() | io.POLL_READ)
@ -55,18 +57,9 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
continue
if cid in CHANNELS:
channel = CHANNELS[cid]
if channel is None:
# TODO send error message to wire
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
# TODO send error message to wire
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
await channel.receive_packet(packet)
continue
await _handle_unallocated(iface, cid)
await _handle_allocated(iface, cid, packet)
else:
await _handle_unallocated(iface, cid)
except ThpError as e:
if __debug__:
@ -76,7 +69,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False):
async def _handle_broadcast(
iface: WireInterface, ctrl_byte, packet
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType
) -> MessageWithId | None:
global _BUFFER
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
@ -91,7 +84,7 @@ async def _handle_broadcast(
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid")
new_channel: Channel = Channel.create_new_channel(iface, _BUFFER)
new_channel: Channel = channel_manager.create_new_channel(iface, _BUFFER)
cid = int.from_bytes(new_channel.channel_id, "big")
CHANNELS[cid] = new_channel
@ -108,6 +101,21 @@ async def _handle_broadcast(
await write_payload_to_wire(iface, response_header, response_data + chksum)
async def _handle_allocated(
iface: WireInterface, cid: int, packet: utils.BufferType
) -> None:
channel = CHANNELS[cid]
if channel is None:
# TODO send error message to wire
raise ThpError("Invalid state of a channel")
if channel.iface is not iface:
# TODO send error message to wire
raise ThpError("Channel has different WireInterface")
if channel.get_channel_state() != ChannelState.UNALLOCATED:
await channel.receive_packet(packet)
async def _handle_unallocated(iface, cid) -> MessageWithId | None:
data = thp_messages.get_error_unallocated_channel()
header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)

Loading…
Cancel
Save