diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index e38828fb7..35bacc464 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -119,7 +119,6 @@ class Channel(Context): else: pass # TODO use small buffer - print("self.buffer2") try: # TODO for now, we create a new big buffer every time. It should be changed self.buffer: utils.BufferType = _get_buffer_for_message( @@ -128,8 +127,6 @@ class Channel(Context): except Exception as e: print(e) print("payload len", payload_length) - print("self.buffer", self.buffer) - print("self.buuffer.type", type(self.buffer)) print("len", len(self.buffer)) await self._buffer_packet_data(self.buffer, packet, 0) print("end init") @@ -179,8 +176,10 @@ class Channel(Context): ) # TODO send ack in response # TODO send handshake init response message - await self._write_encrypted_payload_loop( - thp_messages.get_handshake_init_response() + loop.schedule( + self._write_encrypted_payload_loop( + thp_messages.get_handshake_init_response() + ) ) self.set_channel_state(ChannelState.TH2) return @@ -218,11 +217,11 @@ class Channel(Context): # TODO not finished if session_id not in self.sessions: - raise Exception("Unalloacted session") + raise Exception("Unalloacted session") # TODO send error message session_state = self.sessions[session_id].get_session_state() if session_state is SessionState.UNALLOCATED: - raise Exception("Unalloacted session") + raise Exception("Unalloacted session") # TODO send error message self.sessions[session_id].incoming_message.publish( MessageWithType( @@ -245,8 +244,13 @@ class Channel(Context): ) # TODO remove # TODO send ack in response # TODO send hanshake completion response + loop.schedule( + self._write_encrypted_payload_loop( + thp_messages.get_handshake_init_response() + ) + ) self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) - print("end completed message") + print("end handle completed message") def _decrypt(self, payload) -> bytes: return payload # TODO add decryption process @@ -269,6 +273,7 @@ class Channel(Context): # CALLED BY WORKFLOW / SESSION CONTEXT async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: + print("write") noise_payload_len = self._encode_into_buffer(msg, session_id) @@ -277,15 +282,15 @@ class Channel(Context): # TODO payload_len should be output from trezor.crypto.noise.encode payload_len = noise_payload_len # + TAG_LENGTH # TODO - await self._write_encrypted_payload_loop(self.buffer[:payload_len]) + loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len])) async def _write_encrypted_payload_loop(self, payload: bytes) -> None: - + print("write loop before while") payload_len = len(payload) header = InitHeader( ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len ) - + THP.sync_set_can_send_message(self.channel_cache, False) while True: print("write encrypted payload loop - start") await self._write_encrypted_payload(header, payload, payload_len) @@ -323,7 +328,6 @@ class Channel(Context): async def _wait_for_ack(self) -> None: await loop.sleep(1000) - # TODO retry write def _encode_into_buffer(self, msg: protobuf.MessageType, session_id: int) -> int: @@ -356,7 +360,6 @@ class Channel(Context): from trezor.wire.thp.session_context import SessionContext session = SessionContext.create_new_session(self) - print("help") self.sessions[session.session_id] = session loop.schedule(session.handle()) print("new session created. Session id:", session.session_id) @@ -368,7 +371,7 @@ class Channel(Context): # TODO add debug logging to ACK handling def _handle_received_ACK(self, sync_bit: int) -> None: if self._ack_is_not_expected(): - print("ack is not expeccted") + print("ack is not expected") return if self._ack_has_incorrect_sync_bit(sync_bit): print("ack has incorrect sync bit") diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 8071a8af4..4a48c38ae 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -44,7 +44,6 @@ _CHANNEL_CONTEXTS: dict[int, Channel] = {} def set_buffer(buffer): global _BUFFER _BUFFER = buffer - print("setbuffer,", type(_BUFFER)) async def thp_main_loop(iface: WireInterface, is_debug_session=False):