|
|
@ -21,7 +21,10 @@ from .thp_messages import (
|
|
|
|
CONTINUATION_PACKET,
|
|
|
|
CONTINUATION_PACKET,
|
|
|
|
ENCRYPTED_TRANSPORT,
|
|
|
|
ENCRYPTED_TRANSPORT,
|
|
|
|
ERROR,
|
|
|
|
ERROR,
|
|
|
|
HANDSHAKE_INIT,
|
|
|
|
HANDSHAKE_COMP_REQ,
|
|
|
|
|
|
|
|
HANDSHAKE_COMP_RES,
|
|
|
|
|
|
|
|
HANDSHAKE_INIT_REQ,
|
|
|
|
|
|
|
|
HANDSHAKE_INIT_RES,
|
|
|
|
InitHeader,
|
|
|
|
InitHeader,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from .thp_session import ThpError
|
|
|
|
from .thp_session import ThpError
|
|
|
@ -193,8 +196,9 @@ class Channel(Context):
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
if self._should_be_encrypted() and not _is_ctrl_byte_encrypted_transport(
|
|
|
|
if (
|
|
|
|
ctrl_byte
|
|
|
|
self._should_have_ctrl_byte_encrypted_transport()
|
|
|
|
|
|
|
|
and not _is_ctrl_byte_encrypted_transport(ctrl_byte)
|
|
|
|
):
|
|
|
|
):
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
raise ThpError("Message is not encrypted. Ignoring")
|
|
|
|
raise ThpError("Message is not encrypted. Ignoring")
|
|
|
@ -243,11 +247,13 @@ class Channel(Context):
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH1:
|
|
|
|
if state is ChannelState.TH1:
|
|
|
|
await self._handle_state_TH1(payload_length, message_length, sync_bit)
|
|
|
|
await self._handle_state_TH1(
|
|
|
|
|
|
|
|
payload_length, message_length, ctrl_byte, sync_bit
|
|
|
|
|
|
|
|
)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH2:
|
|
|
|
if state is ChannelState.TH2:
|
|
|
|
await self._handle_state_TH2(message_length, sync_bit)
|
|
|
|
await self._handle_state_TH2(message_length, ctrl_byte, sync_bit)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
if is_channel_state_pairing(state):
|
|
|
|
if is_channel_state_pairing(state):
|
|
|
|
await self._handle_pairing(message_length)
|
|
|
|
await self._handle_pairing(message_length)
|
|
|
@ -255,9 +261,11 @@ class Channel(Context):
|
|
|
|
raise ThpError("Unimplemented channel state")
|
|
|
|
raise ThpError("Unimplemented channel state")
|
|
|
|
|
|
|
|
|
|
|
|
async def _handle_state_TH1(
|
|
|
|
async def _handle_state_TH1(
|
|
|
|
self, payload_length: int, message_length: int, sync_bit: int
|
|
|
|
self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
|
|
|
|
) -> None:
|
|
|
|
) -> None:
|
|
|
|
if not _is_ctrl_byte_handshake_init:
|
|
|
|
if __debug__:
|
|
|
|
|
|
|
|
log.debug(__name__, "handle_state_TH1")
|
|
|
|
|
|
|
|
if not _is_ctrl_byte_handshake_init_req(ctrl_byte):
|
|
|
|
raise ThpError("Message received is not a handshake init request!")
|
|
|
|
raise ThpError("Message received is not a handshake init request!")
|
|
|
|
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
|
|
|
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
|
|
|
raise ThpError("Message received is not a valid handshake init request!")
|
|
|
|
raise ThpError("Message received is not a valid handshake init request!")
|
|
|
@ -269,15 +277,19 @@ class Channel(Context):
|
|
|
|
# send handshake init response message
|
|
|
|
# send handshake init response message
|
|
|
|
loop.schedule(
|
|
|
|
loop.schedule(
|
|
|
|
self._write_encrypted_payload_loop(
|
|
|
|
self._write_encrypted_payload_loop(
|
|
|
|
thp_messages.get_handshake_init_response()
|
|
|
|
HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response()
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.set_channel_state(ChannelState.TH2)
|
|
|
|
self.set_channel_state(ChannelState.TH2)
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None:
|
|
|
|
async def _handle_state_TH2(
|
|
|
|
|
|
|
|
self, message_length: int, ctrl_byte: int, sync_bit: int
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
if __debug__:
|
|
|
|
if __debug__:
|
|
|
|
log.debug(__name__, "handle_state_TH2")
|
|
|
|
log.debug(__name__, "handle_state_TH2")
|
|
|
|
|
|
|
|
if not _is_ctrl_byte_handshake_comp_req(ctrl_byte):
|
|
|
|
|
|
|
|
raise ThpError("Message received is not a handshake completion request!")
|
|
|
|
host_encrypted_static_pubkey = self.buffer[
|
|
|
|
host_encrypted_static_pubkey = self.buffer[
|
|
|
|
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
|
|
|
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
|
|
|
]
|
|
|
|
]
|
|
|
@ -298,7 +310,7 @@ class Channel(Context):
|
|
|
|
# send hanshake completion response
|
|
|
|
# send hanshake completion response
|
|
|
|
loop.schedule(
|
|
|
|
loop.schedule(
|
|
|
|
self._write_encrypted_payload_loop(
|
|
|
|
self._write_encrypted_payload_loop(
|
|
|
|
thp_messages.get_handshake_completion_response()
|
|
|
|
HANDSHAKE_COMP_RES, thp_messages.get_handshake_completion_response()
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
|
|
|
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
|
|
@ -361,8 +373,12 @@ class Channel(Context):
|
|
|
|
# 2. Handle the message
|
|
|
|
# 2. Handle the message
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
def _should_be_encrypted(self) -> bool:
|
|
|
|
def _should_have_ctrl_byte_encrypted_transport(self) -> bool:
|
|
|
|
if self.get_channel_state() in [ChannelState.UNALLOCATED, ChannelState.TH1]:
|
|
|
|
if self.get_channel_state() in [
|
|
|
|
|
|
|
|
ChannelState.UNALLOCATED,
|
|
|
|
|
|
|
|
ChannelState.TH1,
|
|
|
|
|
|
|
|
ChannelState.TH2,
|
|
|
|
|
|
|
|
]:
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
@ -512,15 +528,19 @@ class Channel(Context):
|
|
|
|
payload_length = payload_length + TAG_LENGTH
|
|
|
|
payload_length = payload_length + TAG_LENGTH
|
|
|
|
|
|
|
|
|
|
|
|
loop.schedule(
|
|
|
|
loop.schedule(
|
|
|
|
self._write_encrypted_payload_loop(memoryview(self.buffer[:payload_length]))
|
|
|
|
self._write_encrypted_payload_loop(
|
|
|
|
|
|
|
|
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
|
|
|
async def _write_encrypted_payload_loop(
|
|
|
|
|
|
|
|
self, ctrl_byte: int, payload: bytes
|
|
|
|
|
|
|
|
) -> None:
|
|
|
|
if __debug__:
|
|
|
|
if __debug__:
|
|
|
|
log.debug(__name__, "write_encrypted_payload_loop")
|
|
|
|
log.debug(__name__, "write_encrypted_payload_loop")
|
|
|
|
payload_len = len(payload) + CHECKSUM_LENGTH
|
|
|
|
payload_len = len(payload) + CHECKSUM_LENGTH
|
|
|
|
sync_bit = THP.sync_get_send_bit(self.channel_cache)
|
|
|
|
sync_bit = THP.sync_get_send_bit(self.channel_cache)
|
|
|
|
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
|
|
|
|
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit)
|
|
|
|
header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
|
|
|
|
header = InitHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
|
|
|
|
chksum = checksum.compute(header.to_bytes() + payload)
|
|
|
|
chksum = checksum.compute(header.to_bytes() + payload)
|
|
|
|
payload = payload + chksum
|
|
|
|
payload = payload + chksum
|
|
|
@ -672,8 +692,12 @@ def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
|
|
|
|
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
|
|
|
|
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool:
|
|
|
|
def _is_ctrl_byte_handshake_init_req(ctrl_byte: int) -> bool:
|
|
|
|
return ctrl_byte & 0xEF == HANDSHAKE_INIT
|
|
|
|
return ctrl_byte & 0xEF == HANDSHAKE_INIT_REQ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_ctrl_byte_handshake_comp_req(ctrl_byte: int) -> bool:
|
|
|
|
|
|
|
|
return ctrl_byte & 0xEF == HANDSHAKE_COMP_REQ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
|
|
|
|
def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
|
|
|
|