1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-10 15:30:55 +00:00

chore(core): make await write blocking by default, remove write_force

[no changelog]
This commit is contained in:
M1nd3r 2024-12-09 16:51:19 +01:00
parent c0d46ac762
commit 5a2470e08e
6 changed files with 12 additions and 29 deletions

View File

@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn:
boot_args = None boot_args = None
ctx = get_context() ctx = get_context()
await ctx.write_force(Success(message="Rebooting")) await ctx.write(Success(message="Rebooting"))
# make sure the outgoing USB buffer is flushed # make sure the outgoing USB buffer is flushed
await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE) await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)
# reboot to the bootloader, pass the firmware header hash if any # reboot to the bootloader, pass the firmware header hash if any

View File

@ -49,7 +49,7 @@ async def wipe_device(msg: WipeDevice) -> NoReturn:
translations.deinit() translations.deinit()
translations.erase() translations.erase()
await get_context().write_force(Success(message="Device wiped")) await get_context().write(Success(message="Device wiped"))
storage.wipe_cache() storage.wipe_cache()
# reload settings # reload settings

View File

@ -4,7 +4,7 @@ from trezor import protobuf
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface from trezorio import WireInterface
from typing import Awaitable, Container, TypeVar, overload from typing import Container, TypeVar, overload
from storage.cache_common import DataCache from storage.cache_common import DataCache
@ -72,9 +72,6 @@ class Context:
"""Write a message to the wire.""" """Write a message to the wire."""
... ...
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.write(msg)
async def call( async def call(
self, self,
msg: protobuf.MessageType, msg: protobuf.MessageType,

View File

@ -226,12 +226,11 @@ class Channel:
# WRITE and ENCRYPT # WRITE and ENCRYPT
async def write( def write(
self, self,
msg: protobuf.MessageType, msg: protobuf.MessageType,
session_id: int = 0, session_id: int = 0,
force: bool = False, ) -> Awaitable[None]:
) -> None:
if __debug__ and utils.EMULATOR: if __debug__ and utils.EMULATOR:
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg)) self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
@ -239,9 +238,7 @@ class Channel:
noise_payload_len = memory_manager.encode_into_buffer( noise_payload_len = memory_manager.encode_into_buffer(
self.buffer, msg, session_id self.buffer, msg, session_id
) )
task = self._write_and_encrypt(self.buffer[:noise_payload_len], force) return self._write_and_encrypt(self.buffer[:noise_payload_len])
if task is not None:
await task
def write_error(self, err_type: int) -> Awaitable[None]: def write_error(self, err_type: int) -> Awaitable[None]:
msg_data = err_type.to_bytes(1, "big") msg_data = err_type.to_bytes(1, "big")
@ -255,9 +252,7 @@ class Channel:
self._write_encrypted_payload_loop(ctrl_byte, payload) self._write_encrypted_payload_loop(ctrl_byte, payload)
) )
def _write_and_encrypt( def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]:
self, payload: bytes, force: bool = False
) -> Awaitable[None] | None:
payload_length = len(payload) payload_length = len(payload)
self._encrypt(self.buffer, payload_length) self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH payload_length = payload_length + TAG_LENGTH
@ -266,19 +261,12 @@ class Channel:
self.write_task_spawn.close() # UPS TODO might break something self.write_task_spawn.close() # UPS TODO might break something
print("\nCLOSED\n") print("\nCLOSED\n")
self._prepare_write() self._prepare_write()
if force:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("Writing FORCE message (without async or retransmission).")
return self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length])
)
self.write_task_spawn = loop.spawn( self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop( self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length]) ENCRYPTED, memoryview(self.buffer[:payload_length])
) )
) )
return None return self.write_task_spawn
def _prepare_write(self) -> None: def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false # TODO add condition that disallows to write when can_send_message is false
@ -333,6 +321,7 @@ class Channel:
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
def _can_clear_loop(self) -> bool: def _can_clear_loop(self) -> bool:
return False # TODO
return ( return (
not workflow.tasks not workflow.tasks
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT

View File

@ -180,7 +180,7 @@ async def _handle_ack(ctx: Channel, ack_bit: int) -> None:
ABP.set_sending_allowed(ctx.channel_cache, True) ABP.set_sending_allowed(ctx.channel_cache, True)
if ctx.write_task_spawn is not None: if ctx.write_task_spawn is not None and ctx.write_task_spawn.is_running():
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task') log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await ctx.write_task_spawn await ctx.write_task_spawn

View File

@ -119,11 +119,8 @@ class GenericSessionContext(Context):
return message_handler.wrap_protobuf_load(message.data, expected_type) return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None: def write(self, msg: protobuf.MessageType) -> Awaitable[None]:
return await self.channel.write(msg, self.session_id) return self.channel.write(msg, self.session_id)
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
return self.channel.write(msg, self.session_id, force=True)
def get_session_state(self) -> SessionState: ... def get_session_state(self) -> SessionState: ...