1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-13 09:58:09 +00:00

chore(core): make await write blocking by default, remove write_force

[no changelog]
This commit is contained in:
M1nd3r 2024-12-09 16:51:19 +01:00
parent 084b282d69
commit b9cca4c4f2
6 changed files with 12 additions and 29 deletions

View File

@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn:
boot_args = None
ctx = get_context()
await ctx.write_force(Success(message="Rebooting"))
await ctx.write(Success(message="Rebooting"))
# make sure the outgoing USB buffer is flushed
await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)
# reboot to the bootloader, pass the firmware header hash if any

View File

@ -49,7 +49,7 @@ async def wipe_device(msg: WipeDevice) -> NoReturn:
translations.deinit()
translations.erase()
await get_context().write_force(Success(message="Device wiped"))
await get_context().write(Success(message="Device wiped"))
storage.wipe_cache()
# reload settings

View File

@ -4,7 +4,7 @@ from trezor import protobuf
if TYPE_CHECKING:
from trezorio import WireInterface
from typing import Awaitable, Container, TypeVar, overload
from typing import Container, TypeVar, overload
from storage.cache_common import DataCache
@ -72,9 +72,6 @@ class Context:
"""Write a message to the wire."""
...
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.write(msg)
async def call(
self,
msg: protobuf.MessageType,

View File

@ -226,12 +226,11 @@ class Channel:
# WRITE and ENCRYPT
async def write(
def write(
self,
msg: protobuf.MessageType,
session_id: int = 0,
force: bool = False,
) -> None:
) -> Awaitable[None]:
if __debug__ and utils.EMULATOR:
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
@ -239,9 +238,7 @@ class Channel:
noise_payload_len = memory_manager.encode_into_buffer(
self.buffer, msg, session_id
)
task = self._write_and_encrypt(self.buffer[:noise_payload_len], force)
if task is not None:
await task
return self._write_and_encrypt(self.buffer[:noise_payload_len])
def write_error(self, err_type: int) -> Awaitable[None]:
msg_data = err_type.to_bytes(1, "big")
@ -255,9 +252,7 @@ class Channel:
self._write_encrypted_payload_loop(ctrl_byte, payload)
)
def _write_and_encrypt(
self, payload: bytes, force: bool = False
) -> Awaitable[None] | None:
def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]:
payload_length = len(payload)
self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH
@ -266,19 +261,12 @@ class Channel:
self.write_task_spawn.close() # UPS TODO might break something
print("\nCLOSED\n")
self._prepare_write()
if force:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("Writing FORCE message (without async or retransmission).")
return self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length])
)
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length])
)
)
return None
return self.write_task_spawn
def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false
@ -333,6 +321,7 @@ class Channel:
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
def _can_clear_loop(self) -> bool:
return False # TODO
return (
not workflow.tasks
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT

View File

@ -180,7 +180,7 @@ async def _handle_ack(ctx: Channel, ack_bit: int) -> None:
ABP.set_sending_allowed(ctx.channel_cache, True)
if ctx.write_task_spawn is not None:
if ctx.write_task_spawn is not None and ctx.write_task_spawn.is_running():
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await ctx.write_task_spawn

View File

@ -119,11 +119,8 @@ class GenericSessionContext(Context):
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg, self.session_id)
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.channel.write(msg, self.session_id, force=True)
def write(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.channel.write(msg, self.session_id)
def get_session_state(self) -> SessionState: ...