mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
feat(core): add write_force to allow synchronous writing in ProtocolV2, fix wipe device
[no changelog]
This commit is contained in:
parent
e000a0bde4
commit
f1c411fd35
@ -1,12 +1,19 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.wire.context import get_context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import Success, WipeDevice
|
||||
from typing import NoReturn
|
||||
|
||||
from trezor.messages import WipeDevice
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
|
||||
async def wipe_device(msg: WipeDevice) -> Success:
|
||||
async def wipe_device(msg: WipeDevice) -> NoReturn:
|
||||
import storage
|
||||
from trezor import TR, config, translations
|
||||
from trezor import TR, config, loop, translations
|
||||
from trezor.enums import ButtonRequestType
|
||||
from trezor.messages import Success
|
||||
from trezor.pin import render_empty_loader
|
||||
@ -27,8 +34,11 @@ async def wipe_device(msg: WipeDevice) -> Success:
|
||||
)
|
||||
|
||||
# start an empty progress screen so that the screen is not blank while waiting
|
||||
render_empty_loader(config.StorageMessage.PROCESSING_MSG)
|
||||
|
||||
await get_context().write_force(Success(message="Device wiped"))
|
||||
if __debug__:
|
||||
log.debug(__name__, "Device wipe - start")
|
||||
render_empty_loader(config.StorageMessage.PROCESSING_MSG)
|
||||
# wipe storage
|
||||
storage.wipe()
|
||||
# erase translations
|
||||
@ -37,5 +47,7 @@ async def wipe_device(msg: WipeDevice) -> Success:
|
||||
|
||||
# reload settings
|
||||
reload_settings_from_storage()
|
||||
loop.clear()
|
||||
|
||||
return Success(message="Device wiped")
|
||||
if __debug__:
|
||||
log.debug(__name__, "Device wipe - finished")
|
||||
|
@ -51,6 +51,9 @@ class Context:
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None: ...
|
||||
|
||||
async def write_force(self, msg: protobuf.MessageType) -> None:
|
||||
await self.write(msg)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
|
@ -20,7 +20,6 @@ from . import (
|
||||
interface_manager,
|
||||
memory_manager,
|
||||
received_message_handler,
|
||||
session_manager,
|
||||
)
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
from .thp_messages import ENCRYPTED_TRANSPORT, PacketHeader
|
||||
@ -303,7 +302,12 @@ class Channel:
|
||||
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
async def write(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
session_id: int = 0,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__,
|
||||
@ -317,7 +321,9 @@ class Channel:
|
||||
noise_payload_len = memory_manager.encode_into_buffer(
|
||||
self.buffer, msg, session_id
|
||||
)
|
||||
return self.write_and_encrypt(self.buffer[:noise_payload_len])
|
||||
task = self.write_and_encrypt(self.buffer[:noise_payload_len], force)
|
||||
if task is not None:
|
||||
await task
|
||||
|
||||
def write_error(self, err_type: int) -> Awaitable[None]:
|
||||
msg_data = err_type.to_bytes(1, "big")
|
||||
@ -325,7 +331,9 @@ class Channel:
|
||||
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
|
||||
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
|
||||
|
||||
def write_and_encrypt(self, payload: bytes) -> None:
|
||||
def write_and_encrypt(
|
||||
self, payload: bytes, force: bool = False
|
||||
) -> Awaitable[None] | None:
|
||||
payload_length = len(payload)
|
||||
self._encrypt(self.buffer, payload_length)
|
||||
payload_length = payload_length + TAG_LENGTH
|
||||
@ -334,11 +342,20 @@ class Channel:
|
||||
self.write_task_spawn.close() # UPS TODO might break something
|
||||
print("\nCLOSED\n")
|
||||
self._prepare_write()
|
||||
if force:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "Writing FORCE message (without async or retransmission)."
|
||||
)
|
||||
return self._write_encrypted_payload_loop(
|
||||
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
|
||||
)
|
||||
self.write_task_spawn = loop.spawn(
|
||||
self._write_encrypted_payload_loop(
|
||||
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
|
||||
)
|
||||
)
|
||||
return None
|
||||
|
||||
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
||||
self._prepare_write()
|
||||
|
@ -163,7 +163,7 @@ class PairingContext(Context):
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return self.channel_ctx.write(msg)
|
||||
return await self.channel_ctx.write(msg)
|
||||
|
||||
async def call(
|
||||
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
|
||||
|
@ -124,7 +124,7 @@ async def handle_received_message(
|
||||
)
|
||||
except ThpUnallocatedSessionError as e:
|
||||
error_message = Failure(code=FailureType.ThpUnallocatedSession)
|
||||
ctx.write(error_message, e.session_id)
|
||||
await ctx.write(error_message, e.session_id)
|
||||
except ThpDecryptionError:
|
||||
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
|
||||
ctx.clear()
|
||||
|
@ -138,6 +138,13 @@ class GenericSessionContext(Context):
|
||||
)
|
||||
message: Message = await self.incoming_message.take()
|
||||
if message.type not in expected_types:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"EXPECTED TYPES: %s\nRECEIVED TYPE: %s",
|
||||
str(expected_types),
|
||||
str(message.type),
|
||||
)
|
||||
raise UnexpectedMessageException(message)
|
||||
|
||||
if expected_type is None:
|
||||
@ -146,7 +153,10 @@ class GenericSessionContext(Context):
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return self.channel.write(msg, self.session_id)
|
||||
return await self.channel.write(msg, self.session_id)
|
||||
|
||||
async def write_force(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel.write(msg, self.session_id, force=True)
|
||||
|
||||
def get_session_state(self) -> SessionState: ...
|
||||
|
||||
|
@ -23,9 +23,9 @@ class TransmissionLoop:
|
||||
self.wait_task: loop.spawn | None = None
|
||||
self.min_retransmisson_count_achieved: bool = False
|
||||
|
||||
async def start(self):
|
||||
async def start(self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT):
|
||||
self.min_retransmisson_count_achieved = False
|
||||
for i in range(MAX_RETRANSMISSION_COUNT):
|
||||
for i in range(max_retransmission_count):
|
||||
if i >= MIN_RETRANSMISSION_COUNT:
|
||||
self.min_retransmisson_count_achieved = True
|
||||
await write_payload_to_wire_and_add_checksum(
|
||||
|
Loading…
Reference in New Issue
Block a user