You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
trezor-firmware/core/src/trezor/wire/thp/received_message_handler.py

351 lines
11 KiB

import ustruct # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH
from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType
from trezor.messages import ThpCreateNewSession
from trezor.wire import message_handler
from trezor.wire.protocol_common import MessageWithType
from trezor.wire.thp import ack_handler, thp_messages
from trezor.wire.thp.checksum import CHECKSUM_LENGTH
from trezor.wire.thp.crypto import PUBKEY_LENGTH
from trezor.wire.thp.thp_messages import (
ACK_MESSAGE,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_RES,
InitHeader,
)
from . import (
ChannelState,
SessionState,
checksum,
control_byte,
is_channel_state_pairing,
)
from . import thp_session as THP
from .thp_session import ThpError
from .writer import INIT_DATA_OFFSET, MESSAGE_TYPE_LENGTH, write_payload_to_wire
if TYPE_CHECKING:
from trezor.messages import ThpHandshakeCompletionReqNoisePayload
from . import ChannelContext
if __debug__:
from . import state_to_str
async def handle_received_message(
    ctx: ChannelContext, message_buffer: utils.BufferType
) -> None:
    """Process one complete message that arrived on the channel.

    The header is parsed from the start of `message_buffer` and the checksum
    is verified. What happens next depends on the control byte:

    * an ACK is handed to `_handle_ack` and nothing else is done;
    * a message whose sync bit differs from the expected one is ACKed and
      then rejected with `ThpError`;
    * any other message is ACKed, the expected receive sync bit is flipped
      and the message is dispatched according to the channel state.

    Raises:
        ThpError: on an invalid checksum, when the channel state requires an
            encrypted-transport control byte and the message lacks it, or
            when the sync bit is not the expected one.
    """
    if __debug__:
        log.debug(__name__, "handle_received_message")

    # ">BHH": one unsigned byte (control byte) followed by two big-endian
    # 16-bit fields; the middle one is not needed here, the last one is the
    # payload length.
    header_fields = ustruct.unpack(">BHH", message_buffer)
    ctrl_byte = header_fields[0]
    payload_length = header_fields[2]
    message_length = INIT_DATA_OFFSET + payload_length

    _check_checksum(message_length, message_buffer)

    # Synchronization process: the sync bit is bit 4 of the control byte.
    sync_bit = (ctrl_byte >> 4) & 0x01
    if __debug__:
        log.debug(
            __name__,
            "handle_completed_message - sync bit of message: %d",
            sync_bit,
        )

    # 1: ACKs are handled separately and never answered with an ACK.
    if control_byte.is_ack(ctrl_byte):
        await _handle_ack(ctx, sync_bit)
        return

    if _should_have_ctrl_byte_encrypted_transport(ctx):
        if not control_byte.is_encrypted_transport(ctrl_byte):
            raise ThpError("Message is not encrypted. Ignoring")

    # 2: Detect a message with an unexpected synchronization bit.
    unexpected_sync = sync_bit != THP.sync_get_receive_expected_bit(
        ctx.channel_cache
    )
    if __debug__:
        if unexpected_sync:
            log.debug(
                __name__, "Received message with an unexpected synchronization bit"
            )

    # 3: The message is ACKed whether or not its sync bit was expected ...
    await _send_ack(ctx, sync_bit)

    # ... but only a message with the expected sync bit is processed further.
    if unexpected_sync:
        raise ThpError("Received message with an unexpected synchronization bit")

    THP.sync_set_receive_expected_bit(ctx.channel_cache, 1 - sync_bit)

    await _handle_message_to_app_or_channel(
        ctx, payload_length, message_length, ctrl_byte, sync_bit
    )
    if __debug__:
        log.debug(__name__, "handle_received_message - end")
async def _send_ack(ctx: ChannelContext, ack_bit: int) -> None:
    """Write an ACK message carrying `ack_bit` as its sync bit.

    The header is built from the ACK control byte (with the sync bit added),
    the channel id and `CHECKSUM_LENGTH`; the checksum computed over the
    serialized header is passed to `write_payload_to_wire` together with the
    header.
    """
    ack_header = InitHeader(
        control_byte.add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit),
        ctx.get_channel_id_int(),
        CHECKSUM_LENGTH,
    )
    header_checksum = checksum.compute(ack_header.to_bytes())

    if __debug__:
        log.debug(
            __name__,
            "Writing ACK message to a channel with id: %d, sync bit: %d",
            ctx.get_channel_id_int(),
            ack_bit,
        )

    await write_payload_to_wire(ctx.iface, ack_header, header_checksum)
def _check_checksum(message_length: int, message_buffer: utils.BufferType):
    """Verify the checksum stored at the end of a message.

    The last `CHECKSUM_LENGTH` bytes of the first `message_length` bytes of
    `message_buffer` are validated against all the bytes that precede them.

    Raises:
        ThpError: if the checksum is not valid.
    """
    if __debug__:
        log.debug(__name__, "check_checksum")

    # Index where the covered data ends and the stored checksum begins.
    checksum_start = message_length - CHECKSUM_LENGTH
    received_checksum = message_buffer[checksum_start:message_length]
    covered_data = message_buffer[:checksum_start]

    if checksum.is_valid(checksum=received_checksum, data=covered_data):
        return

    if __debug__:
        log.debug(__name__, "Invalid checksum, ignoring message.")
    raise ThpError("Invalid checksum, ignoring message.")
# TEST THIS
async def _handle_ack(ctx: ChannelContext, sync_bit: int):
    """React to a received ACK message.

    An ACK that `ack_handler.is_ack_valid` rejects is silently ignored.
    Otherwise the "waiting for ack" timeout (if any) is closed, the channel
    is marked as allowed to send again and, when a write task is set on the
    context, it is awaited.
    """
    if not ack_handler.is_ack_valid(ctx.channel_cache, sync_bit):
        return

    # From here on the ACK was expected and carries the correct sync bit.
    if __debug__:
        log.debug(__name__, "Received ACK message with correct sync bit")

    ack_timeout = ctx.waiting_for_ack_timeout
    if ack_timeout is not None:
        ack_timeout.close()
        if __debug__:
            log.debug(__name__, 'Closed "waiting for ack" task')

    THP.sync_set_can_send_message(ctx.channel_cache, True)

    pending_write = ctx.write_task_spawn
    if pending_write is None:
        return

    if __debug__:
        log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')

    # NOTE (carried over from the original author): the awaited write task
    # may end up calling loop.clear(), which terminates this coroutine.
    # Keep this await as the very last statement - code placed after it
    # might never be executed.
    await pending_write
async def _handle_message_to_app_or_channel(
    ctx: ChannelContext,
    payload_length: int,
    message_length: int,
    ctrl_byte: int,
    sync_bit: int,
) -> None:
    """Dispatch a message to the handler matching the current channel state.

    Raises:
        ThpError: if no handler exists for the current channel state.
    """
    state = ctx.get_channel_state()
    if __debug__:
        log.debug(__name__, "state: %s", state_to_str(state))

    if state is ChannelState.ENCRYPTED_TRANSPORT:
        await _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
    elif state is ChannelState.TH1:
        await _handle_state_TH1(
            ctx, payload_length, message_length, ctrl_byte, sync_bit
        )
    elif state is ChannelState.TH2:
        await _handle_state_TH2(ctx, message_length, ctrl_byte, sync_bit)
    elif is_channel_state_pairing(state):
        await _handle_pairing(ctx, message_length)
    else:
        raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
    ctx: ChannelContext,
    payload_length: int,
    message_length: int,
    ctrl_byte: int,
    sync_bit: int,
) -> None:
    """Handle a message received while the channel is in state TH1.

    The message must be a handshake init request whose payload is exactly
    `PUBKEY_LENGTH + CHECKSUM_LENGTH` bytes long. The bytes between the
    header and the checksum are stored in the channel cache as the host
    ephemeral key, the handshake init response is written and the channel
    moves to state TH2. `sync_bit` is accepted but not used.

    Raises:
        ThpError: if the control byte or the payload length does not match a
            handshake init request.
    """
    if __debug__:
        log.debug(__name__, "handle_state_TH1")

    if not control_byte.is_handshake_init_req(ctrl_byte):
        raise ThpError("Message received is not a handshake init request!")
    if payload_length != PUBKEY_LENGTH + CHECKSUM_LENGTH:
        raise ThpError("Message received is not a valid handshake init request!")

    # Copy the key out of the shared channel buffer before caching it.
    key_end = message_length - CHECKSUM_LENGTH
    host_key = bytearray(ctx.buffer[INIT_DATA_OFFSET:key_end])
    cache_thp.set_channel_host_ephemeral_key(ctx.channel_cache, host_key)

    # Send the handshake init response message, then advance the state.
    await ctx.write_handshake_message(
        HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response()
    )
    ctx.set_channel_state(ChannelState.TH2)
async def _handle_state_TH2(
    ctx: ChannelContext, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
    """Handle a message received while the channel is in state TH2.

    The message must be a handshake completion request. Its noise payload is
    decoded as `ThpHandshakeCompletionReqNoisePayload`, the pairing methods
    it lists are appended to `ctx.selected_pairing_methods`, the handshake
    completion response is written and the channel state is advanced.
    `sync_bit` is accepted but not used.

    NOTE: credential recognition is not implemented yet - `paired` is
    hard-coded to True, so the channel currently always moves to
    ENCRYPTED_TRANSPORT and never to TP1.

    Raises:
        ThpError: if the control byte is not a handshake completion request.
    """
    if __debug__:
        log.debug(__name__, "handle_state_TH2")

    if not control_byte.is_handshake_comp_req(ctrl_byte):
        raise ThpError("Message received is not a handshake completion request!")

    # Buffer layout after the header: [key + tag][noise payload + tag][checksum]
    key_end = INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
    checksum_start = message_length - CHECKSUM_LENGTH

    # These two slices are only consumed by the debug log below.
    host_encrypted_static_pubkey = ctx.buffer[INIT_DATA_OFFSET:key_end]
    handshake_completion_request_noise_payload = ctx.buffer[key_end:checksum_start]

    # The decoder is given the noise payload without its trailing tag.
    noise_payload = thp_messages.decode_message(
        ctx.buffer[key_end : checksum_start - TAG_LENGTH],
        0,
        "ThpHandshakeCompletionReqNoisePayload",
    )
    if TYPE_CHECKING:
        assert ThpHandshakeCompletionReqNoisePayload.is_type_of(noise_payload)

    ctx.selected_pairing_methods.extend(noise_payload.pairing_methods)

    if __debug__:
        log.debug(
            __name__,
            "host static pubkey: %s, noise payload: %s",
            utils.get_bytes_as_str(host_encrypted_static_pubkey),
            utils.get_bytes_as_str(handshake_completion_request_noise_payload),
        )

    # TODO add credential recognition
    paired: bool = True  # TODO should be output from credential check

    # Send the handshake completion response.
    await ctx.write_handshake_message(
        HANDSHAKE_COMP_RES,
        thp_messages.get_handshake_completion_response(paired),
    )

    ctx.set_channel_state(
        ChannelState.ENCRYPTED_TRANSPORT if paired else ChannelState.TP1
    )
async def _handle_state_ENCRYPTED_TRANSPORT(
    ctx: ChannelContext, message_length: int
) -> None:
    """Handle a message received while the channel is in ENCRYPTED_TRANSPORT.

    The buffer is decrypted and the session id and message type are read
    from the start of the data area. Session id 0 is handled by the channel
    itself; any other id must belong to a session known to the context that
    is not in the UNALLOCATED state, and the message is published to that
    session's incoming message queue.

    Raises:
        ThpError: if the session id is unknown or the session is
            unallocated (an error is written to the host first).
    """
    if __debug__:
        log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")

    ctx.decrypt_buffer(message_length)

    # ">BH": one byte of session id followed by a big-endian 16-bit type.
    session_id, message_type = ustruct.unpack(">BH", ctx.buffer[INIT_DATA_OFFSET:])

    if session_id == 0:
        await _handle_channel_message(ctx, message_length, message_type)
        return

    # `or` short-circuits, so the session is only looked up when it exists.
    if (
        session_id not in ctx.sessions
        or ctx.sessions[session_id].get_session_state() is SessionState.UNALLOCATED
    ):
        await ctx.write_error(FailureType.ThpUnallocatedSession, "Unallocated session")
        raise ThpError("Unalloacted session")

    # Message body: after session id and message type, before tag and checksum.
    data_start = INIT_DATA_OFFSET + MESSAGE_TYPE_LENGTH + SESSION_ID_LENGTH
    data_end = message_length - CHECKSUM_LENGTH - TAG_LENGTH
    ctx.sessions[session_id].incoming_message.publish(
        MessageWithType(message_type, ctx.buffer[data_start:data_end])
    )
async def _handle_pairing(ctx: ChannelContext, message_length: int) -> None:
    """Handle a message received while the channel is in a pairing state.

    A `PairingContext` is created and its handler scheduled the first time a
    pairing message arrives. The buffer is then decrypted and the message
    (type plus body) is published to the pairing context's incoming message
    queue.
    """
    from .pairing_context import PairingContext

    if ctx.connection_context is None:
        ctx.connection_context = PairingContext(ctx)
        loop.schedule(ctx.connection_context.handle())

    ctx.decrypt_buffer(message_length)

    # The 16-bit message type sits right after the session id field.
    type_offset = INIT_DATA_OFFSET + SESSION_ID_LENGTH
    (message_type,) = ustruct.unpack(">H", ctx.buffer[type_offset:])

    # Message body: after the message type, before tag and checksum.
    data_start = type_offset + MESSAGE_TYPE_LENGTH
    data_end = message_length - CHECKSUM_LENGTH - TAG_LENGTH
    ctx.connection_context.incoming_message.publish(
        MessageWithType(message_type, ctx.buffer[data_start:data_end])
    )
    # TODO (outline kept from the original):
    # 1. Check that message is expected with respect to the current state
    # 2. Handle the message
def _should_have_ctrl_byte_encrypted_transport(ctx: ChannelContext) -> bool:
    """Tell whether messages in the current state need the encrypted control byte.

    Returns False while the channel is UNALLOCATED or in one of the handshake
    states TH1 / TH2, and True in every other state.
    """
    unencrypted_states = (
        ChannelState.UNALLOCATED,
        ChannelState.TH1,
        ChannelState.TH2,
    )
    return ctx.get_channel_state() not in unencrypted_states
async def _handle_channel_message(
    ctx: ChannelContext, message_length: int, message_type: int
) -> None:
    """Handle a message addressed to the channel itself (session id 0).

    The message body is decoded as the protobuf type registered for
    `message_type`. Only `ThpCreateNewSession` is currently accepted; its
    handler is looked up, awaited, and the response it returns is written
    back through the channel.

    Args:
        ctx: channel the message arrived on; its buffer holds the message.
        message_length: length of the whole message, header and checksum
            included.
        message_type: wire type of the message.

    Raises:
        ThpError: if the decoded message is not a `ThpCreateNewSession`.
    """
    # The body starts after the session id and message type fields. Use the
    # named field lengths (as the session and pairing handlers do) rather
    # than a bare `3`, so the offset follows the constants if they change.
    data_start = INIT_DATA_OFFSET + SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
    # The trailing tag and checksum are not part of the body.
    data_end = message_length - CHECKSUM_LENGTH - TAG_LENGTH
    buf = ctx.buffer[data_start:data_end]

    expected_type = protobuf.type_for_wire(message_type)
    message = message_handler.wrap_protobuf_load(buf, expected_type)

    if not ThpCreateNewSession.is_type_of(message):
        raise ThpError(
            "The received message cannot be handled by channel itself. It must be sent to allocated session."
        )
    # TODO handle other messages than CreateNewSession

    # Imported locally, as in the original code.
    from trezor.wire.thp.handler_provider import get_handler_for_channel_message

    handler = get_handler_for_channel_message(message)
    response_message = await handler(ctx, message)
    # TODO handle
    await ctx.write(response_message)

    if __debug__:
        log.debug(__name__, "_handle_channel_message - end")