|
|
|
@ -119,7 +119,6 @@ class Channel(Context):
|
|
|
|
|
else:
|
|
|
|
|
pass
|
|
|
|
|
# TODO use small buffer
|
|
|
|
|
print("self.buffer2")
|
|
|
|
|
try:
|
|
|
|
|
# TODO for now, we create a new big buffer every time. It should be changed
|
|
|
|
|
self.buffer: utils.BufferType = _get_buffer_for_message(
|
|
|
|
@ -128,8 +127,6 @@ class Channel(Context):
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(e)
|
|
|
|
|
print("payload len", payload_length)
|
|
|
|
|
print("self.buffer", self.buffer)
|
|
|
|
|
print("self.buuffer.type", type(self.buffer))
|
|
|
|
|
print("len", len(self.buffer))
|
|
|
|
|
await self._buffer_packet_data(self.buffer, packet, 0)
|
|
|
|
|
print("end init")
|
|
|
|
@ -179,8 +176,10 @@ class Channel(Context):
|
|
|
|
|
)
|
|
|
|
|
# TODO send ack in response
|
|
|
|
|
# TODO send handshake init response message
|
|
|
|
|
await self._write_encrypted_payload_loop(
|
|
|
|
|
thp_messages.get_handshake_init_response()
|
|
|
|
|
loop.schedule(
|
|
|
|
|
self._write_encrypted_payload_loop(
|
|
|
|
|
thp_messages.get_handshake_init_response()
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
self.set_channel_state(ChannelState.TH2)
|
|
|
|
|
return
|
|
|
|
@ -218,11 +217,11 @@ class Channel(Context):
|
|
|
|
|
# TODO not finished
|
|
|
|
|
|
|
|
|
|
if session_id not in self.sessions:
|
|
|
|
|
raise Exception("Unalloacted session")
|
|
|
|
|
raise Exception("Unalloacted session") # TODO send error message
|
|
|
|
|
|
|
|
|
|
session_state = self.sessions[session_id].get_session_state()
|
|
|
|
|
if session_state is SessionState.UNALLOCATED:
|
|
|
|
|
raise Exception("Unalloacted session")
|
|
|
|
|
raise Exception("Unalloacted session") # TODO send error message
|
|
|
|
|
|
|
|
|
|
self.sessions[session_id].incoming_message.publish(
|
|
|
|
|
MessageWithType(
|
|
|
|
@ -245,8 +244,13 @@ class Channel(Context):
|
|
|
|
|
) # TODO remove
|
|
|
|
|
# TODO send ack in response
|
|
|
|
|
# TODO send hanshake completion response
|
|
|
|
|
loop.schedule(
|
|
|
|
|
self._write_encrypted_payload_loop(
|
|
|
|
|
thp_messages.get_handshake_init_response()
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
|
|
|
|
print("end completed message")
|
|
|
|
|
print("end handle completed message")
|
|
|
|
|
|
|
|
|
|
def _decrypt(self, payload) -> bytes:
|
|
|
|
|
return payload # TODO add decryption process
|
|
|
|
@ -269,6 +273,7 @@ class Channel(Context):
|
|
|
|
|
# CALLED BY WORKFLOW / SESSION CONTEXT
|
|
|
|
|
|
|
|
|
|
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
|
|
|
|
print("write")
|
|
|
|
|
|
|
|
|
|
noise_payload_len = self._encode_into_buffer(msg, session_id)
|
|
|
|
|
|
|
|
|
@ -277,15 +282,15 @@ class Channel(Context):
|
|
|
|
|
# TODO payload_len should be output from trezor.crypto.noise.encode
|
|
|
|
|
payload_len = noise_payload_len # + TAG_LENGTH # TODO
|
|
|
|
|
|
|
|
|
|
await self._write_encrypted_payload_loop(self.buffer[:payload_len])
|
|
|
|
|
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len]))
|
|
|
|
|
|
|
|
|
|
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
|
|
|
|
|
|
|
|
|
print("write loop before while")
|
|
|
|
|
payload_len = len(payload)
|
|
|
|
|
header = InitHeader(
|
|
|
|
|
ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
THP.sync_set_can_send_message(self.channel_cache, False)
|
|
|
|
|
while True:
|
|
|
|
|
print("write encrypted payload loop - start")
|
|
|
|
|
await self._write_encrypted_payload(header, payload, payload_len)
|
|
|
|
@ -323,7 +328,6 @@ class Channel(Context):
|
|
|
|
|
|
|
|
|
|
async def _wait_for_ack(self) -> None:
|
|
|
|
|
await loop.sleep(1000)
|
|
|
|
|
# TODO retry write
|
|
|
|
|
|
|
|
|
|
def _encode_into_buffer(self, msg: protobuf.MessageType, session_id: int) -> int:
|
|
|
|
|
|
|
|
|
@ -356,7 +360,6 @@ class Channel(Context):
|
|
|
|
|
from trezor.wire.thp.session_context import SessionContext
|
|
|
|
|
|
|
|
|
|
session = SessionContext.create_new_session(self)
|
|
|
|
|
print("help")
|
|
|
|
|
self.sessions[session.session_id] = session
|
|
|
|
|
loop.schedule(session.handle())
|
|
|
|
|
print("new session created. Session id:", session.session_id)
|
|
|
|
@ -368,7 +371,7 @@ class Channel(Context):
|
|
|
|
|
# TODO add debug logging to ACK handling
|
|
|
|
|
def _handle_received_ACK(self, sync_bit: int) -> None:
|
|
|
|
|
if self._ack_is_not_expected():
|
|
|
|
|
print("ack is not expeccted")
|
|
|
|
|
print("ack is not expected")
|
|
|
|
|
return
|
|
|
|
|
if self._ack_has_incorrect_sync_bit(sync_bit):
|
|
|
|
|
print("ack has incorrect sync bit")
|
|
|
|
|