1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

refactor(core): refactor channel

This commit is contained in:
M1nd3r 2024-04-02 17:36:14 +02:00
parent 3f590bc11d
commit e7f5f3d3f2

View File

@ -1,7 +1,7 @@
import ustruct # pyright: ignore[reportMissingModuleSource]
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
from ubinascii import hexlify
from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource]
import usb
from storage import cache_thp
@ -136,25 +136,16 @@ class Channel(Context):
async def _handle_cont_packet(self, packet: utils.BufferType):
print("cont")
if not self.is_cont_packet_expected:
return # Continuation packet is not expected, ignoring
raise ThpError("Continuation packet is not expected, ignoring")
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
async def _handle_completed_message(self):
async def _handle_completed_message(self) -> None:
print("handling completed message")
print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache))
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
msg_len = payload_length + INIT_DATA_OFFSET
message_length = payload_length + INIT_DATA_OFFSET
print("checksum check")
# printBytes(self.buffer)
if not checksum.is_valid(
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
):
# checksum is not valid -> ignore message
self._todo_clear_buffer()
return
self._check_checksum(message_length)
# Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4
@ -184,103 +175,123 @@ class Channel(Context):
sync_bit,
)
await self._sendAck(sync_bit)
print("___set receive bit to", 1 - sync_bit)
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
self._handle_valid_message(payload_length, message_length, ctrl_byte)
print("end handle completed message")
def _check_checksum(self, message_length: int):
print("checksum check")
if not checksum.is_valid(
checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length],
data=self.buffer[: message_length - CHECKSUM_LENGTH],
):
self._todo_clear_buffer()
raise ThpError("Invalid checksum, ignoring message.")
def _handle_valid_message(
self, payload_length: int, message_length: int, ctrl_byte: int
) -> None:
state = self.get_channel_state()
if __debug__:
log.debug(__name__, _state_to_str(state))
if state is ChannelState.TH1:
if not _is_ctrl_byte_handshake_init:
raise ThpError("Message received is not a handshake init request!")
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError(
"Message received is not a valid handshake init request!"
)
host_ephemeral_key = bytearray(
self.buffer[INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(
self.channel_cache, host_ephemeral_key
)
# TODO send ack in response
# TODO send handshake init response message
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
)
)
self.set_channel_state(ChannelState.TH2)
return
self._handle_state_TH1(payload_length, message_length)
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
print("Message is not encrypted. Ignoring")
# TODO ignore message
self._todo_clear_buffer()
return
raise ThpError("Message is not encrypted. Ignoring")
if state is ChannelState.ENCRYPTED_TRANSPORT:
self._decrypt_buffer()
session_id, message_type = ustruct.unpack(
">BH", self.buffer[INIT_DATA_OFFSET:]
)
if session_id == 0:
try:
buf = self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH]
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
print(message)
# TODO handle other messages than CreateNewSession
assert isinstance(message, ThpCreateNewSession)
print("passphrase:", message.passphrase)
# await thp_messages.handle_CreateNewSession(message)
if message.passphrase is not None:
self.create_new_session(message.passphrase)
else:
self.create_new_session()
except Exception as e:
print("Proč??")
print(e)
return
# TODO not finished
if session_id not in self.sessions:
raise Exception("Unalloacted session") # TODO send error message
session_state = self.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED:
raise Exception("Unalloacted session") # TODO send error message
self.sessions[session_id].incoming_message.publish(
MessageWithType(
message_type,
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
)
)
self._handle_state_ENCRYPTED_TRANSPORT(message_length)
if state is ChannelState.TH2:
print("th2 branche")
host_encrypted_static_pubkey = self.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = self.buffer[
INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
]
print(
host_encrypted_static_pubkey,
handshake_completion_request_noise_payload,
) # TODO remove
# TODO send ack in response
# TODO send hanshake completion response
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
)
self._handle_state_TH2(message_length)
def _handle_state_TH1(self, payload_length: int, message_length: int) -> None:
if not _is_ctrl_byte_handshake_init:
raise ThpError("Message received is not a handshake init request!")
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
host_ephemeral_key = bytearray(
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
# TODO send ack in response
# TODO send handshake init response message
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
)
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
print("end handle completed message")
)
self.set_channel_state(ChannelState.TH2)
return
def _handle_state_TH2(self, message_length: int) -> None:
print("th2 branche")
host_encrypted_static_pubkey = self.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
]
handshake_completion_request_noise_payload = self.buffer[
INIT_DATA_OFFSET
+ KEY_LENGTH
+ TAG_LENGTH : message_length
- CHECKSUM_LENGTH
]
print(
host_encrypted_static_pubkey,
handshake_completion_request_noise_payload,
) # TODO remove
# TODO send ack in response
# TODO send hanshake completion response
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
)
)
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
self._decrypt_buffer()
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
self._handle_channel_comms(message_length, message_type)
return
if session_id not in self.sessions:
raise ThpError("Unalloacted session")
session_state = self.sessions[session_id].get_session_state()
if session_state is SessionState.UNALLOCATED:
raise ThpError("Unalloacted session")
self.sessions[session_id].incoming_message.publish(
MessageWithType(
message_type,
self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH],
)
)
def _handle_channel_comms(self, message_length: int, message_type: int) -> None:
try:
buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH]
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
print(message)
# TODO handle other messages than CreateNewSession
assert isinstance(message, ThpCreateNewSession)
print("passphrase:", message.passphrase)
# await thp_messages.handle_CreateNewSession(message)
if message.passphrase is not None:
self.create_new_session(message.passphrase)
else:
self.create_new_session()
except Exception as e:
print("Proč??")
print(e)
# TODO not finished
def _decrypt(self, payload) -> bytes:
return payload # TODO add decryption process