mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
refactor(core): refactor channel
This commit is contained in:
parent
3f590bc11d
commit
e7f5f3d3f2
@ -1,7 +1,7 @@
|
||||
import ustruct # pyright: ignore[reportMissingModuleSource]
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
|
||||
from ubinascii import hexlify
|
||||
from ubinascii import hexlify # pyright: ignore[reportMissingModuleSource]
|
||||
|
||||
import usb
|
||||
from storage import cache_thp
|
||||
@ -136,25 +136,16 @@ class Channel(Context):
|
||||
async def _handle_cont_packet(self, packet: utils.BufferType):
|
||||
print("cont")
|
||||
if not self.is_cont_packet_expected:
|
||||
return # Continuation packet is not expected, ignoring
|
||||
raise ThpError("Continuation packet is not expected, ignoring")
|
||||
await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET)
|
||||
|
||||
async def _handle_completed_message(self):
|
||||
async def _handle_completed_message(self) -> None:
|
||||
print("handling completed message")
|
||||
print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache))
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
|
||||
msg_len = payload_length + INIT_DATA_OFFSET
|
||||
message_length = payload_length + INIT_DATA_OFFSET
|
||||
|
||||
print("checksum check")
|
||||
# printBytes(self.buffer)
|
||||
|
||||
if not checksum.is_valid(
|
||||
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
|
||||
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
|
||||
):
|
||||
# checksum is not valid -> ignore message
|
||||
self._todo_clear_buffer()
|
||||
return
|
||||
self._check_checksum(message_length)
|
||||
|
||||
# Synchronization process
|
||||
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||
@ -184,103 +175,123 @@ class Channel(Context):
|
||||
sync_bit,
|
||||
)
|
||||
await self._sendAck(sync_bit)
|
||||
print("___set receive bit to", 1 - sync_bit)
|
||||
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
|
||||
|
||||
self._handle_valid_message(payload_length, message_length, ctrl_byte)
|
||||
print("end handle completed message")
|
||||
|
||||
def _check_checksum(self, message_length: int):
|
||||
print("checksum check")
|
||||
if not checksum.is_valid(
|
||||
checksum=self.buffer[message_length - CHECKSUM_LENGTH : message_length],
|
||||
data=self.buffer[: message_length - CHECKSUM_LENGTH],
|
||||
):
|
||||
self._todo_clear_buffer()
|
||||
raise ThpError("Invalid checksum, ignoring message.")
|
||||
|
||||
def _handle_valid_message(
|
||||
self, payload_length: int, message_length: int, ctrl_byte: int
|
||||
) -> None:
|
||||
state = self.get_channel_state()
|
||||
if __debug__:
|
||||
log.debug(__name__, _state_to_str(state))
|
||||
|
||||
if state is ChannelState.TH1:
|
||||
if not _is_ctrl_byte_handshake_init:
|
||||
raise ThpError("Message received is not a handshake init request!")
|
||||
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
||||
raise ThpError(
|
||||
"Message received is not a valid handshake init request!"
|
||||
)
|
||||
host_ephemeral_key = bytearray(
|
||||
self.buffer[INIT_DATA_OFFSET : msg_len - CHECKSUM_LENGTH]
|
||||
)
|
||||
cache_thp.set_channel_host_ephemeral_key(
|
||||
self.channel_cache, host_ephemeral_key
|
||||
)
|
||||
# TODO send ack in response
|
||||
# TODO send handshake init response message
|
||||
loop.schedule(
|
||||
self._write_encrypted_payload_loop(
|
||||
thp_messages.get_handshake_init_response()
|
||||
)
|
||||
)
|
||||
self.set_channel_state(ChannelState.TH2)
|
||||
return
|
||||
self._handle_state_TH1(payload_length, message_length)
|
||||
|
||||
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
||||
print("Message is not encrypted. Ignoring")
|
||||
# TODO ignore message
|
||||
self._todo_clear_buffer()
|
||||
return
|
||||
raise ThpError("Message is not encrypted. Ignoring")
|
||||
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
self._decrypt_buffer()
|
||||
session_id, message_type = ustruct.unpack(
|
||||
">BH", self.buffer[INIT_DATA_OFFSET:]
|
||||
)
|
||||
if session_id == 0:
|
||||
try:
|
||||
buf = self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH]
|
||||
|
||||
expected_type = protobuf.type_for_wire(message_type)
|
||||
message = message_handler.wrap_protobuf_load(buf, expected_type)
|
||||
print(message)
|
||||
# TODO handle other messages than CreateNewSession
|
||||
assert isinstance(message, ThpCreateNewSession)
|
||||
print("passphrase:", message.passphrase)
|
||||
# await thp_messages.handle_CreateNewSession(message)
|
||||
if message.passphrase is not None:
|
||||
self.create_new_session(message.passphrase)
|
||||
else:
|
||||
self.create_new_session()
|
||||
except Exception as e:
|
||||
print("Proč??")
|
||||
print(e)
|
||||
return
|
||||
# TODO not finished
|
||||
|
||||
if session_id not in self.sessions:
|
||||
raise Exception("Unalloacted session") # TODO send error message
|
||||
|
||||
session_state = self.sessions[session_id].get_session_state()
|
||||
if session_state is SessionState.UNALLOCATED:
|
||||
raise Exception("Unalloacted session") # TODO send error message
|
||||
|
||||
self.sessions[session_id].incoming_message.publish(
|
||||
MessageWithType(
|
||||
message_type,
|
||||
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
|
||||
)
|
||||
)
|
||||
self._handle_state_ENCRYPTED_TRANSPORT(message_length)
|
||||
|
||||
if state is ChannelState.TH2:
|
||||
print("th2 branche")
|
||||
host_encrypted_static_pubkey = self.buffer[
|
||||
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
||||
]
|
||||
handshake_completion_request_noise_payload = self.buffer[
|
||||
INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH : msg_len - CHECKSUM_LENGTH
|
||||
]
|
||||
print(
|
||||
host_encrypted_static_pubkey,
|
||||
handshake_completion_request_noise_payload,
|
||||
) # TODO remove
|
||||
# TODO send ack in response
|
||||
# TODO send hanshake completion response
|
||||
loop.schedule(
|
||||
self._write_encrypted_payload_loop(
|
||||
thp_messages.get_handshake_init_response()
|
||||
)
|
||||
self._handle_state_TH2(message_length)
|
||||
|
||||
def _handle_state_TH1(self, payload_length: int, message_length: int) -> None:
|
||||
if not _is_ctrl_byte_handshake_init:
|
||||
raise ThpError("Message received is not a handshake init request!")
|
||||
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
||||
raise ThpError("Message received is not a valid handshake init request!")
|
||||
host_ephemeral_key = bytearray(
|
||||
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
|
||||
)
|
||||
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
|
||||
# TODO send ack in response
|
||||
# TODO send handshake init response message
|
||||
loop.schedule(
|
||||
self._write_encrypted_payload_loop(
|
||||
thp_messages.get_handshake_init_response()
|
||||
)
|
||||
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
print("end handle completed message")
|
||||
)
|
||||
self.set_channel_state(ChannelState.TH2)
|
||||
return
|
||||
|
||||
def _handle_state_TH2(self, message_length: int) -> None:
|
||||
print("th2 branche")
|
||||
host_encrypted_static_pubkey = self.buffer[
|
||||
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
||||
]
|
||||
handshake_completion_request_noise_payload = self.buffer[
|
||||
INIT_DATA_OFFSET
|
||||
+ KEY_LENGTH
|
||||
+ TAG_LENGTH : message_length
|
||||
- CHECKSUM_LENGTH
|
||||
]
|
||||
print(
|
||||
host_encrypted_static_pubkey,
|
||||
handshake_completion_request_noise_payload,
|
||||
) # TODO remove
|
||||
# TODO send ack in response
|
||||
# TODO send hanshake completion response
|
||||
loop.schedule(
|
||||
self._write_encrypted_payload_loop(
|
||||
thp_messages.get_handshake_init_response()
|
||||
)
|
||||
)
|
||||
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
|
||||
def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
|
||||
self._decrypt_buffer()
|
||||
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
|
||||
if session_id == 0:
|
||||
self._handle_channel_comms(message_length, message_type)
|
||||
return
|
||||
|
||||
if session_id not in self.sessions:
|
||||
raise ThpError("Unalloacted session")
|
||||
|
||||
session_state = self.sessions[session_id].get_session_state()
|
||||
if session_state is SessionState.UNALLOCATED:
|
||||
raise ThpError("Unalloacted session")
|
||||
|
||||
self.sessions[session_id].incoming_message.publish(
|
||||
MessageWithType(
|
||||
message_type,
|
||||
self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH],
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_channel_comms(self, message_length: int, message_type: int) -> None:
|
||||
try:
|
||||
buf = self.buffer[INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH]
|
||||
|
||||
expected_type = protobuf.type_for_wire(message_type)
|
||||
message = message_handler.wrap_protobuf_load(buf, expected_type)
|
||||
print(message)
|
||||
# TODO handle other messages than CreateNewSession
|
||||
assert isinstance(message, ThpCreateNewSession)
|
||||
print("passphrase:", message.passphrase)
|
||||
# await thp_messages.handle_CreateNewSession(message)
|
||||
if message.passphrase is not None:
|
||||
self.create_new_session(message.passphrase)
|
||||
else:
|
||||
self.create_new_session()
|
||||
except Exception as e:
|
||||
print("Proč??")
|
||||
print(e)
|
||||
# TODO not finished
|
||||
|
||||
def _decrypt(self, payload) -> bytes:
|
||||
return payload # TODO add decryption process
|
||||
|
Loading…
Reference in New Issue
Block a user