mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-10 15:30:55 +00:00
chore(core): make await write blocking by default, remove write_force
[no changelog]
This commit is contained in:
parent
c0d46ac762
commit
5a2470e08e
@ -89,7 +89,7 @@ async def reboot_to_bootloader(msg: RebootToBootloader) -> NoReturn:
|
|||||||
boot_args = None
|
boot_args = None
|
||||||
|
|
||||||
ctx = get_context()
|
ctx = get_context()
|
||||||
await ctx.write_force(Success(message="Rebooting"))
|
await ctx.write(Success(message="Rebooting"))
|
||||||
# make sure the outgoing USB buffer is flushed
|
# make sure the outgoing USB buffer is flushed
|
||||||
await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)
|
await loop.wait(ctx.iface.iface_num() | io.POLL_WRITE)
|
||||||
# reboot to the bootloader, pass the firmware header hash if any
|
# reboot to the bootloader, pass the firmware header hash if any
|
||||||
|
@ -49,7 +49,7 @@ async def wipe_device(msg: WipeDevice) -> NoReturn:
|
|||||||
translations.deinit()
|
translations.deinit()
|
||||||
translations.erase()
|
translations.erase()
|
||||||
|
|
||||||
await get_context().write_force(Success(message="Device wiped"))
|
await get_context().write(Success(message="Device wiped"))
|
||||||
storage.wipe_cache()
|
storage.wipe_cache()
|
||||||
|
|
||||||
# reload settings
|
# reload settings
|
||||||
|
@ -4,7 +4,7 @@ from trezor import protobuf
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezorio import WireInterface
|
from trezorio import WireInterface
|
||||||
from typing import Awaitable, Container, TypeVar, overload
|
from typing import Container, TypeVar, overload
|
||||||
|
|
||||||
from storage.cache_common import DataCache
|
from storage.cache_common import DataCache
|
||||||
|
|
||||||
@ -72,9 +72,6 @@ class Context:
|
|||||||
"""Write a message to the wire."""
|
"""Write a message to the wire."""
|
||||||
...
|
...
|
||||||
|
|
||||||
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
|
|
||||||
return self.write(msg)
|
|
||||||
|
|
||||||
async def call(
|
async def call(
|
||||||
self,
|
self,
|
||||||
msg: protobuf.MessageType,
|
msg: protobuf.MessageType,
|
||||||
|
@ -226,12 +226,11 @@ class Channel:
|
|||||||
|
|
||||||
# WRITE and ENCRYPT
|
# WRITE and ENCRYPT
|
||||||
|
|
||||||
async def write(
|
def write(
|
||||||
self,
|
self,
|
||||||
msg: protobuf.MessageType,
|
msg: protobuf.MessageType,
|
||||||
session_id: int = 0,
|
session_id: int = 0,
|
||||||
force: bool = False,
|
) -> Awaitable[None]:
|
||||||
) -> None:
|
|
||||||
if __debug__ and utils.EMULATOR:
|
if __debug__ and utils.EMULATOR:
|
||||||
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
|
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
|
||||||
|
|
||||||
@ -239,9 +238,7 @@ class Channel:
|
|||||||
noise_payload_len = memory_manager.encode_into_buffer(
|
noise_payload_len = memory_manager.encode_into_buffer(
|
||||||
self.buffer, msg, session_id
|
self.buffer, msg, session_id
|
||||||
)
|
)
|
||||||
task = self._write_and_encrypt(self.buffer[:noise_payload_len], force)
|
return self._write_and_encrypt(self.buffer[:noise_payload_len])
|
||||||
if task is not None:
|
|
||||||
await task
|
|
||||||
|
|
||||||
def write_error(self, err_type: int) -> Awaitable[None]:
|
def write_error(self, err_type: int) -> Awaitable[None]:
|
||||||
msg_data = err_type.to_bytes(1, "big")
|
msg_data = err_type.to_bytes(1, "big")
|
||||||
@ -255,9 +252,7 @@ class Channel:
|
|||||||
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _write_and_encrypt(
|
def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]:
|
||||||
self, payload: bytes, force: bool = False
|
|
||||||
) -> Awaitable[None] | None:
|
|
||||||
payload_length = len(payload)
|
payload_length = len(payload)
|
||||||
self._encrypt(self.buffer, payload_length)
|
self._encrypt(self.buffer, payload_length)
|
||||||
payload_length = payload_length + TAG_LENGTH
|
payload_length = payload_length + TAG_LENGTH
|
||||||
@ -266,19 +261,12 @@ class Channel:
|
|||||||
self.write_task_spawn.close() # UPS TODO might break something
|
self.write_task_spawn.close() # UPS TODO might break something
|
||||||
print("\nCLOSED\n")
|
print("\nCLOSED\n")
|
||||||
self._prepare_write()
|
self._prepare_write()
|
||||||
if force:
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
self._log("Writing FORCE message (without async or retransmission).")
|
|
||||||
|
|
||||||
return self._write_encrypted_payload_loop(
|
|
||||||
ENCRYPTED, memoryview(self.buffer[:payload_length])
|
|
||||||
)
|
|
||||||
self.write_task_spawn = loop.spawn(
|
self.write_task_spawn = loop.spawn(
|
||||||
self._write_encrypted_payload_loop(
|
self._write_encrypted_payload_loop(
|
||||||
ENCRYPTED, memoryview(self.buffer[:payload_length])
|
ENCRYPTED, memoryview(self.buffer[:payload_length])
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return None
|
return self.write_task_spawn
|
||||||
|
|
||||||
def _prepare_write(self) -> None:
|
def _prepare_write(self) -> None:
|
||||||
# TODO add condition that disallows to write when can_send_message is false
|
# TODO add condition that disallows to write when can_send_message is false
|
||||||
@ -333,6 +321,7 @@ class Channel:
|
|||||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||||
|
|
||||||
def _can_clear_loop(self) -> bool:
|
def _can_clear_loop(self) -> bool:
|
||||||
|
return False # TODO
|
||||||
return (
|
return (
|
||||||
not workflow.tasks
|
not workflow.tasks
|
||||||
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
|
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
|
||||||
|
@ -180,7 +180,7 @@ async def _handle_ack(ctx: Channel, ack_bit: int) -> None:
|
|||||||
|
|
||||||
ABP.set_sending_allowed(ctx.channel_cache, True)
|
ABP.set_sending_allowed(ctx.channel_cache, True)
|
||||||
|
|
||||||
if ctx.write_task_spawn is not None:
|
if ctx.write_task_spawn is not None and ctx.write_task_spawn.is_running():
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
|
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
|
||||||
await ctx.write_task_spawn
|
await ctx.write_task_spawn
|
||||||
|
@ -119,11 +119,8 @@ class GenericSessionContext(Context):
|
|||||||
|
|
||||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||||
|
|
||||||
async def write(self, msg: protobuf.MessageType) -> None:
|
def write(self, msg: protobuf.MessageType) -> Awaitable[None]:
|
||||||
return await self.channel.write(msg, self.session_id)
|
return self.channel.write(msg, self.session_id)
|
||||||
|
|
||||||
def write_force(self, msg: protobuf.MessageType) -> Awaitable[None]:
|
|
||||||
return self.channel.write(msg, self.session_id, force=True)
|
|
||||||
|
|
||||||
def get_session_state(self) -> SessionState: ...
|
def get_session_state(self) -> SessionState: ...
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user