Create new session response

Fix style
M1nd3r 2 months ago
parent 750c37697e
commit 8f776fcced

@ -41,8 +41,6 @@ MESSAGE_TYPE_LENGTH = const(2)
REPORT_LENGTH = const(64) REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000) MAX_PAYLOAD_LEN = const(60000)
_ACK_MESSAGE = 0x20
class Channel(Context): class Channel(Context):
def __init__(self, channel_cache: ChannelCache) -> None: def __init__(self, channel_cache: ChannelCache) -> None:
@ -163,21 +161,17 @@ class Channel(Context):
log.debug( log.debug(
__name__, "Received message with an unexpected synchronization bit" __name__, "Received message with an unexpected synchronization bit"
) )
await self._sendAck(sync_bit) await self._send_ack(sync_bit)
raise ThpError("Received message with an unexpected synchronization bit") raise ThpError("Received message with an unexpected synchronization bit")
# 3: Send ACK in response # 3: Send ACK in response
if __debug__: await self._send_ack(sync_bit)
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
self.channel_id,
sync_bit,
)
await self._sendAck(sync_bit)
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit) THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
self._handle_valid_message(payload_length, message_length, ctrl_byte) await self._handle_valid_message(
payload_length, message_length, ctrl_byte, sync_bit
)
print("end handle completed message") print("end handle completed message")
def _check_checksum(self, message_length: int): def _check_checksum(self, message_length: int):
@ -189,15 +183,16 @@ class Channel(Context):
self._todo_clear_buffer() self._todo_clear_buffer()
raise ThpError("Invalid checksum, ignoring message.") raise ThpError("Invalid checksum, ignoring message.")
def _handle_valid_message( async def _handle_valid_message(
self, payload_length: int, message_length: int, ctrl_byte: int self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
) -> None: ) -> None:
state = self.get_channel_state() state = self.get_channel_state()
if __debug__: if __debug__:
log.debug(__name__, _state_to_str(state)) log.debug(__name__, _state_to_str(state))
if state is ChannelState.TH1: if state is ChannelState.TH1:
self._handle_state_TH1(payload_length, message_length) await self._handle_state_TH1(payload_length, message_length, sync_bit)
return
if not _is_ctrl_byte_encrypted_transport(ctrl_byte): if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
self._todo_clear_buffer() self._todo_clear_buffer()
@ -205,11 +200,15 @@ class Channel(Context):
if state is ChannelState.ENCRYPTED_TRANSPORT: if state is ChannelState.ENCRYPTED_TRANSPORT:
self._handle_state_ENCRYPTED_TRANSPORT(message_length) self._handle_state_ENCRYPTED_TRANSPORT(message_length)
return
if state is ChannelState.TH2: if state is ChannelState.TH2:
self._handle_state_TH2(message_length) await self._handle_state_TH2(message_length, sync_bit)
return
def _handle_state_TH1(self, payload_length: int, message_length: int) -> None: async def _handle_state_TH1(
self, payload_length: int, message_length: int, sync_bit: int
) -> None:
if not _is_ctrl_byte_handshake_init: if not _is_ctrl_byte_handshake_init:
raise ThpError("Message received is not a handshake init request!") raise ThpError("Message received is not a handshake init request!")
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH: if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
@ -218,8 +217,10 @@ class Channel(Context):
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH] self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
) )
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key) cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
# TODO send ack in response
# TODO send handshake init response message await self._send_ack(sync_bit)
# send handshake init response message
loop.schedule( loop.schedule(
self._write_encrypted_payload_loop( self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response() thp_messages.get_handshake_init_response()
@ -228,7 +229,7 @@ class Channel(Context):
self.set_channel_state(ChannelState.TH2) self.set_channel_state(ChannelState.TH2)
return return
def _handle_state_TH2(self, message_length: int) -> None: async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None:
print("th2 branche") print("th2 branche")
host_encrypted_static_pubkey = self.buffer[ host_encrypted_static_pubkey = self.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
@ -243,11 +244,11 @@ class Channel(Context):
host_encrypted_static_pubkey, host_encrypted_static_pubkey,
handshake_completion_request_noise_payload, handshake_completion_request_noise_payload,
) # TODO remove ) # TODO remove
# TODO send ack in response
# TODO send hanshake completion response # send hanshake completion response
loop.schedule( loop.schedule(
self._write_encrypted_payload_loop( self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response() thp_messages.get_handshake_completion_response()
) )
) )
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
@ -288,6 +289,11 @@ class Channel(Context):
self.create_new_session(message.passphrase) self.create_new_session(message.passphrase)
else: else:
self.create_new_session() self.create_new_session()
# TODO reuse existing buffer and compute size dynamically
bufferrone = bytearray(2)
message_size: int = thp_messages.get_new_session_message(bufferrone)
print(message_size) # TODO adjust
loop.schedule(self._write_encrypted_payload_loop(bufferrone))
except Exception as e: except Exception as e:
print("Proč??") print("Proč??")
print(e) print(e)
@ -311,13 +317,20 @@ class Channel(Context):
self.expected_payload_length = 0 self.expected_payload_length = 0
self.is_cont_packet_expected = False self.is_cont_packet_expected = False
async def _sendAck(self, ack_bit: int) -> None: async def _send_ack(self, ack_bit: int) -> None:
ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit) ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = InitHeader( header = InitHeader(
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
) )
chksum = checksum.compute(header.to_bytes()) chksum = checksum.compute(header.to_bytes())
await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH) if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
self.channel_id,
ack_bit,
)
await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH)
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit): def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
if sync_bit == 0: if sync_bit == 0:
@ -335,19 +348,22 @@ class Channel(Context):
# trezor.crypto.noise.encode(key, payload=self.buffer) # trezor.crypto.noise.encode(key, payload=self.buffer)
# TODO payload_len should be output from trezor.crypto.noise.encode # TODO payload_len should be output from trezor.crypto.noise.encode, I guess
payload_len = noise_payload_len # + TAG_LENGTH # TODO payload_len = noise_payload_len # + TAG_LENGTH # TODO
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len])) loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len]))
async def _write_encrypted_payload_loop(self, payload: bytes) -> None: async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
print("write loop before while") print("write loop before while")
payload_len = len(payload) payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = THP.sync_get_send_bit(self.channel_cache) sync_bit = THP.sync_get_send_bit(self.channel_cache)
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit) ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
header = InitHeader( header = InitHeader(
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
) )
chksum = checksum.compute(header.to_bytes() + payload)
payload = payload + chksum
# TODO add condition that disallows to write when can_send_message is false # TODO add condition that disallows to write when can_send_message is false
THP.sync_set_can_send_message(self.channel_cache, False) THP.sync_set_can_send_message(self.channel_cache, False)
while True: while True:
@ -357,7 +373,7 @@ class Channel(Context):
" send_sync_bit:", " send_sync_bit:",
THP.sync_get_send_bit(self.channel_cache), THP.sync_get_send_bit(self.channel_cache),
) )
await self._write_encrypted_payload(header, payload, payload_len) await self._write_payload_to_wire(header, payload, payload_len)
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack()) self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
try: try:
await self.waiting_for_ack_timeout await self.waiting_for_ack_timeout
@ -365,28 +381,29 @@ class Channel(Context):
THP.sync_set_send_bit_to_opposite(self.channel_cache) THP.sync_set_send_bit_to_opposite(self.channel_cache)
break break
async def _write_encrypted_payload( async def _write_payload_to_wire(
self, header: InitHeader, payload: bytes, payload_len: int self, header: InitHeader, payload: bytes, payload_len: int
): ):
print("write payload to wire:")
# prepare the report buffer with header data # prepare the report buffer with header data
report = bytearray(REPORT_LENGTH) report = bytearray(REPORT_LENGTH)
header.pack_to_buffer(report) header.pack_to_buffer(report)
# write initial report # write initial report
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0) nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
await self._write_report(report) await self._write_report_to_wire(report)
# if we have more data to write, use continuation reports for it # if we have more data to write, use continuation reports for it
if nwritten < payload_len: if nwritten < payload_len:
header.pack_to_cont_buffer(report) header.pack_to_cont_buffer(report)
while nwritten < payload_len: while nwritten < payload_len:
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten) nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
await self._write_report(report) await self._write_report_to_wire(report)
async def _write_report(self, report: utils.BufferType) -> None: async def _write_report_to_wire(self, report: utils.BufferType) -> None:
while True: while True:
await loop.wait(self.iface.iface_num() | io.POLL_WRITE) await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
printBytes(report) # TODO remove
n = self.iface.write(report) n = self.iface.write(report)
if n == len(report): if n == len(report):
return return

@ -2,7 +2,7 @@ import ustruct # pyright:ignore[reportMissingModuleSource]
from storage.cache_thp import BROADCAST_CHANNEL_ID from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf from trezor import protobuf
from trezor.messages import ThpCreateNewSession from trezor.messages import ThpCreateNewSession, ThpNewSession
from .. import message_handler from .. import message_handler
from ..protocol_common import Message from ..protocol_common import Message
@ -15,6 +15,9 @@ ACK_MESSAGE = 0x20
_ERROR = 0x41 _ERROR = 0x41
_CHANNEL_ALLOCATION_RES = 0x40 _CHANNEL_ALLOCATION_RES = 0x40
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
class InitHeader: class InitHeader:
format_str = ">BHH" format_str = ">BHH"
@ -79,7 +82,21 @@ def get_error_unallocated_channel() -> bytes:
def get_handshake_init_response() -> bytes: def get_handshake_init_response() -> bytes:
return b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" # TODO implement # TODO implement - 32 bytes ephemeral key, 48 bytes encrypted and masked public key, 16 bytes ciphertext of empty string (i.e. noise tag)
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x40\x41\x42\x43\x44\x45\x46\x47\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
def get_handshake_completion_response() -> bytes:
return (
TREZOR_STATE_PAIRED
+ b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
)
def get_new_session_message(buffer: bytearray) -> int:
msg = ThpNewSession(new_session_id=1)
encoded_msg = protobuf.encode(buffer, msg)
return encoded_msg
def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType: def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType:

Loading…
Cancel
Save