diff --git a/core/src/apps/management/reboot_to_bootloader.py b/core/src/apps/management/reboot_to_bootloader.py index 2213d2c17a..85596c0268 100644 --- a/core/src/apps/management/reboot_to_bootloader.py +++ b/core/src/apps/management/reboot_to_bootloader.py @@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn: boot_args = None ctx = get_context() - await ctx.write_force(Success(message="Rebooting")) + await ctx.write(Success(message="Rebooting")) # make sure the outgoing USB buffer is flushed await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE) # reboot to the bootloader, pass the firmware header hash if any diff --git a/core/src/apps/management/wipe_device.py b/core/src/apps/management/wipe_device.py index 5a08bf1b18..bdc076c4de 100644 --- a/core/src/apps/management/wipe_device.py +++ b/core/src/apps/management/wipe_device.py @@ -49,7 +49,7 @@ async def wipe_device(msg: WipeDevice) -> NoReturn: translations.deinit() translations.erase() - await get_context().write_force(Success(message="Device wiped")) + await get_context().write(Success(message="Device wiped")) storage.wipe_cache() # reload settings diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index 0e54afe8c3..ed4105517b 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -4,7 +4,7 @@ from trezor import protobuf if TYPE_CHECKING: from trezorio import WireInterface - from typing import Awaitable, Container, TypeVar, overload + from typing import Container, TypeVar, overload from storage.cache_common import DataCache @@ -72,9 +72,6 @@ class Context: """Write a message to the wire.""" ... - def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: - return self.write(msg) - async def call( self, msg: protobuf.MessageType, diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index d2ddcbf96e..8a604f43e8 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -226,12 +226,11 @@ class Channel: # WRITE and ENCRYPT - async def write( + def write( self, msg: protobuf.MessageType, session_id: int = 0, - force: bool = False, - ) -> None: + ) -> Awaitable[None]: if __debug__ and utils.EMULATOR: self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg)) @@ -239,9 +238,7 @@ class Channel: noise_payload_len = memory_manager.encode_into_buffer( self.buffer, msg, session_id ) - task = self._write_and_encrypt(self.buffer[:noise_payload_len], force) - if task is not None: - await task + return self._write_and_encrypt(self.buffer[:noise_payload_len]) def write_error(self, err_type: int) -> Awaitable[None]: msg_data = err_type.to_bytes(1, "big") @@ -255,9 +252,7 @@ class Channel: self._write_encrypted_payload_loop(ctrl_byte, payload) ) - def _write_and_encrypt( - self, payload: bytes, force: bool = False - ) -> Awaitable[None] | None: + def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]: payload_length = len(payload) self._encrypt(self.buffer, payload_length) payload_length = payload_length + TAG_LENGTH @@ -266,19 +261,12 @@ class Channel: self.write_task_spawn.close() # UPS TODO might break something print("\nCLOSED\n") self._prepare_write() - if force: - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - self._log("Writing FORCE message (without async or retransmission).") - - return self._write_encrypted_payload_loop( - ENCRYPTED, memoryview(self.buffer[:payload_length]) - ) self.write_task_spawn = loop.spawn( self._write_encrypted_payload_loop( ENCRYPTED, memoryview(self.buffer[:payload_length]) ) ) - return None + return self.write_task_spawn def _prepare_write(self) -> None: # TODO add condition that disallows to write when can_send_message is false @@ -333,6 +321,7 @@ class Channel: buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag def _can_clear_loop(self) -> bool: + return False # TODO return ( not workflow.tasks ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py index 2a7fbbdf30..23edecba41 100644 --- a/core/src/trezor/wire/thp/received_message_handler.py +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -180,7 +180,7 @@ async def _handle_ack(ctx: Channel, ack_bit: int) -> None: ABP.set_sending_allowed(ctx.channel_cache, True) - if ctx.write_task_spawn is not None: + if ctx.write_task_spawn is not None and ctx.write_task_spawn.is_running(): if __debug__ and utils.ALLOW_DEBUG_MESSAGES: log.debug(__name__, 'Control to "write_encrypted_payload_loop" task') await ctx.write_task_spawn diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 4520738757..d694269002 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -119,11 +119,8 @@ class GenericSessionContext(Context): return message_handler.wrap_protobuf_load(message.data, expected_type) - async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel.write(msg, self.session_id) - - def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]: - return self.channel.write(msg, self.session_id, force=True) + def write(self, msg: protobuf.MessageType) -> Awaitable[None]: + return self.channel.write(msg, self.session_id) def get_session_state(self) -> SessionState: ...