1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

feat(core): adjust control bytes

This commit is contained in:
M1nd3r 2024-04-09 17:02:04 +02:00
parent 6b0c0accd7
commit 58b9060ec6
2 changed files with 47 additions and 19 deletions

View File

@ -21,7 +21,10 @@ from .thp_messages import (
CONTINUATION_PACKET,
ENCRYPTED_TRANSPORT,
ERROR,
HANDSHAKE_INIT,
HANDSHAKE_COMP_REQ,
HANDSHAKE_COMP_RES,
HANDSHAKE_INIT_REQ,
HANDSHAKE_INIT_RES,
InitHeader,
)
from .thp_session import ThpError
@ -193,8 +196,9 @@ class Channel(Context):
self._todo_clear_buffer()
return
if self._should_be_encrypted() and not _is_ctrl_byte_encrypted_transport(
ctrl_byte
if (
self._should_have_ctrl_byte_encrypted_transport()
and not _is_ctrl_byte_encrypted_transport(ctrl_byte)
):
self._todo_clear_buffer()
raise ThpError("Message is not encrypted. Ignoring")
@ -243,11 +247,13 @@ class Channel(Context):
return
if state is ChannelState.TH1:
await self._handle_state_TH1(payload_length, message_length, sync_bit)
await self._handle_state_TH1(
payload_length, message_length, ctrl_byte, sync_bit
)
return
if state is ChannelState.TH2:
await self._handle_state_TH2(message_length, sync_bit)
await self._handle_state_TH2(message_length, ctrl_byte, sync_bit)
return
if is_channel_state_pairing(state):
await self._handle_pairing(message_length)
@ -255,9 +261,11 @@ class Channel(Context):
raise ThpError("Unimplemented channel state")
async def _handle_state_TH1(
self, payload_length: int, message_length: int, sync_bit: int
self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
if not _is_ctrl_byte_handshake_init:
if __debug__:
log.debug(__name__, "handle_state_TH1")
if not _is_ctrl_byte_handshake_init_req(ctrl_byte):
raise ThpError("Message received is not a handshake init request!")
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
@ -269,15 +277,19 @@ class Channel(Context):
# send handshake init response message
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response()
)
)
self.set_channel_state(ChannelState.TH2)
return
async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None:
async def _handle_state_TH2(
self, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
if __debug__:
log.debug(__name__, "handle_state_TH2")
if not _is_ctrl_byte_handshake_comp_req(ctrl_byte):
raise ThpError("Message received is not a handshake completion request!")
host_encrypted_static_pubkey = self.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
@ -298,7 +310,7 @@ class Channel(Context):
# send hanshake completion response
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_completion_response()
HANDSHAKE_COMP_RES, thp_messages.get_handshake_completion_response()
)
)
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
@ -361,8 +373,12 @@ class Channel(Context):
# 2. Handle the message
pass
def _should_be_encrypted(self) -> bool:
if self.get_channel_state() in [ChannelState.UNALLOCATED, ChannelState.TH1]:
def _should_have_ctrl_byte_encrypted_transport(self) -> bool:
if self.get_channel_state() in [
ChannelState.UNALLOCATED,
ChannelState.TH1,
ChannelState.TH2,
]:
return False
return True
@ -512,15 +528,19 @@ class Channel(Context):
payload_length = payload_length + TAG_LENGTH
loop.schedule(
self._write_encrypted_payload_loop(memoryview(self.buffer[:payload_length]))
self._write_encrypted_payload_loop(
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
)
)
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
async def _write_encrypted_payload_loop(
self, ctrl_byte: int, payload: bytes
) -> None:
if __debug__:
log.debug(__name__, "write_encrypted_payload_loop")
payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = THP.sync_get_send_bit(self.channel_cache)
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit)
header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
chksum = checksum.compute(header.to_bytes() + payload)
payload = payload + chksum
@ -672,8 +692,12 @@ def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_INIT
def _is_ctrl_byte_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ
def _is_ctrl_byte_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:

View File

@ -9,8 +9,12 @@ from ..protocol_common import Message
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
ENCRYPTED_TRANSPORT = 0x02
HANDSHAKE_INIT = 0x00
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
ACK_MESSAGE = 0x20
ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40