Refactor channel

M1nd3r/thp5
M1nd3r 2 months ago
parent 7f455f9931
commit 3317921365

@ -1,7 +1,7 @@
import ustruct # pyright: ignore[reportMissingModuleSource] import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource] from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
from ubinascii import hexlify from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource]
import usb import usb
from storage import cache_thp from storage import cache_thp
@ -136,25 +136,16 @@ class Channel(Context):
async def _handle_cont_packet(self, packet: utils.BufferType): async def _handle_cont_packet(self, packet: utils.BufferType):
print("cont") print("cont")
if not self.is_cont_packet_expected: if not self.is_cont_packet_expected:
return # Continuation packet is not expected, ignoring raise ThpError("Continuation packet is not expected, ignoring")
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET) await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
async def _handle_completed_message(self): async def _handle_completed_message(self) -> None:
print("handling completed message") print("handling completed message")
print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache)) print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache))
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
msg_len = payload_length + INIT_DATA_OFFSET message_length = payload_length + INIT_DATA_OFFSET
print("checksum check") self._check_checksum(message_length)
# printBytes(self.buffer)
if not checksum.is_valid(
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
):
# checksum is not valid -> ignore message
self._todo_clear_buffer()
return
# Synchronization process # Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4 sync_bit = (ctrl_byte & 0x10) >> 4
@ -184,103 +175,123 @@ class Channel(Context):
sync_bit, sync_bit,
) )
await self._sendAck(sync_bit) await self._sendAck(sync_bit)
print("___set receive bit to", 1 - sync_bit)
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit) THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
self._handle_valid_message(payload_length, message_length, ctrl_byte)
print("end handle completed message")
def _check_checksum(self, message_length: int):
print("checksum check")
if not checksum.is_valid(
checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length],
data=self.buffer[: message_length - CHECKSUM_LENGTH],
):
self._todo_clear_buffer()
raise ThpError("Invalid checksum, ignoring message.")
def _handle_valid_message(
self, payload_length: int, message_length: int, ctrl_byte: int
) -> None:
state = self.get_channel_state() state = self.get_channel_state()
if __debug__: if __debug__:
log.debug(__name__, _state_to_str(state)) log.debug(__name__, _state_to_str(state))
if state is ChannelState.TH1: if state is ChannelState.TH1:
if not _is_ctrl_byte_handshake_init: self._handle_state_TH1(payload_length, message_length)
raise ThpError("Message received is not a handshake init request!")
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError(
"Message received is not a valid handshake init request!"
)
host_ephemeral_key = bytearray(
self.buffer[INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(
self.channel_cache, host_ephemeral_key
)
# TODO send ack in response
# TODO send handshake init response message
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
)
)
self.set_channel_state(ChannelState.TH2)
return
if not _is_ctrl_byte_encrypted_transport(ctrl_byte): if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
print("Message is not encrypted. Ignoring")
# TODO ignore message
self._todo_clear_buffer() self._todo_clear_buffer()
return raise ThpError("Message is not encrypted. Ignoring")
if state is ChannelState.ENCRYPTED_TRANSPORT: if state is ChannelState.ENCRYPTED_TRANSPORT:
self._decrypt_buffer() self._handle_state_ENCRYPTED_TRANSPORT(message_length)
session_id, message_type = ustruct.unpack(
">BH", self.buffer[INIT_DATA_OFFSET:] if state is ChannelState.TH2:
self._handle_state_TH2(message_length)
def _handle_state_TH1(self, payload_length: int, message_length: int) -> None:
if not _is_ctrl_byte_handshake_init:
raise ThpError("Message received is not a handshake init request!")
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
host_ephemeral_key = bytearray(
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
# TODO send ack in response
# TODO send handshake init response message
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
) )
if session_id == 0: )
try: self.set_channel_state(ChannelState.TH2)
buf = self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH] return
expected_type = protobuf.type_for_wire(message_type) def _handle_state_TH2(self, message_length: int) -> None:
message = message_handler.wrap_protobuf_load(buf, expected_type) print("th2 branche")
print(message) host_encrypted_static_pubkey = self.buffer[
# TODO handle other messages than CreateNewSession INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
assert isinstance(message, ThpCreateNewSession) ]
print("passphrase:", message.passphrase) handshake_completion_request_noise_payload = self.buffer[
# await thp_messages.handle_CreateNewSession(message) INIT_DATA_OFFSET
if message.passphrase is not None: + KEY_LENGTH
self.create_new_session(message.passphrase) + TAG_LENGTH : message_length
else: - CHECKSUM_LENGTH
self.create_new_session() ]
except Exception as e: print(
print("Proč??") host_encrypted_static_pubkey,
print(e) handshake_completion_request_noise_payload,
return ) # TODO remove
# TODO not finished # TODO send ack in response
# TODO send hanshake completion response
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
)
)
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
if session_id not in self.sessions: def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
raise Exception("Unalloacted session") # TODO send error message self._decrypt_buffer()
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
self._handle_channel_comms(message_length, message_type)
return
session_state = self.sessions[session_id].get_session_state() if session_id not in self.sessions:
if session_state is SessionState.UNALLOCATED: raise ThpError("Unalloacted session")
raise Exception("Unalloacted session") # TODO send error message
self.sessions[session_id].incoming_message.publish( session_state = self.sessions[session_id].get_session_state()
MessageWithType( if session_state is SessionState.UNALLOCATED:
message_type, raise ThpError("Unalloacted session")
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
)
)
if state is ChannelState.TH2: self.sessions[session_id].incoming_message.publish(
print("th2 branche") MessageWithType(
host_encrypted_static_pubkey = self.buffer[ message_type,
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH],
]
handshake_completion_request_noise_payload = self.buffer[
INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
]
print(
host_encrypted_static_pubkey,
handshake_completion_request_noise_payload,
) # TODO remove
# TODO send ack in response
# TODO send hanshake completion response
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
)
) )
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) )
print("end handle completed message")
def _handle_channel_comms(self, message_length: int, message_type: int) -> None:
try:
buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH]
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
print(message)
# TODO handle other messages than CreateNewSession
assert isinstance(message, ThpCreateNewSession)
print("passphrase:", message.passphrase)
# await thp_messages.handle_CreateNewSession(message)
if message.passphrase is not None:
self.create_new_session(message.passphrase)
else:
self.create_new_session()
except Exception as e:
print("Proč??")
print(e)
# TODO not finished
def _decrypt(self, payload) -> bytes: def _decrypt(self, payload) -> bytes:
return payload # TODO add decryption process return payload # TODO add decryption process

Loading…
Cancel
Save