1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

feat(core): add write_force to allow synchronous writing in ProtocolV2, fix wipe device

[no changelog]
This commit is contained in:
M1nd3r 2024-10-02 12:42:18 +02:00
parent e000a0bde4
commit f1c411fd35
7 changed files with 56 additions and 14 deletions

View File

@ -1,12 +1,19 @@
from typing import TYPE_CHECKING
from trezor.wire.context import get_context
if TYPE_CHECKING:
from trezor.messages import Success, WipeDevice
from typing import NoReturn
from trezor.messages import WipeDevice
if __debug__:
from trezor import log
async def wipe_device(msg: WipeDevice) -> Success:
async def wipe_device(msg: WipeDevice) -> NoReturn:
import storage
from trezor import TR, config, translations
from trezor import TR, config, loop, translations
from trezor.enums import ButtonRequestType
from trezor.messages import Success
from trezor.pin import render_empty_loader
@ -27,8 +34,11 @@ async def wipe_device(msg: WipeDevice) -> Success:
)
# start an empty progress screen so that the screen is not blank while waiting
render_empty_loader(config.StorageMessage.PROCESSING_MSG)
await get_context().write_force(Success(message="Device wiped"))
if __debug__:
log.debug(__name__, "Device wipe - start")
render_empty_loader(config.StorageMessage.PROCESSING_MSG)
# wipe storage
storage.wipe()
# erase translations
@ -37,5 +47,7 @@ async def wipe_device(msg: WipeDevice) -> Success:
# reload settings
reload_settings_from_storage()
loop.clear()
return Success(message="Device wiped")
if __debug__:
log.debug(__name__, "Device wipe - finished")

View File

@ -51,6 +51,9 @@ class Context:
async def write(self, msg: protobuf.MessageType) -> None: ...
async def write_force(self, msg: protobuf.MessageType) -> None:
await self.write(msg)
async def call(
self,
msg: protobuf.MessageType,

View File

@ -20,7 +20,6 @@ from . import (
interface_manager,
memory_manager,
received_message_handler,
session_manager,
)
from .checksum import CHECKSUM_LENGTH
from .thp_messages import ENCRYPTED_TRANSPORT, PacketHeader
@ -303,7 +302,12 @@ class Channel:
# CALLED BY WORKFLOW / SESSION CONTEXT
def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
async def write(
self,
msg: protobuf.MessageType,
session_id: int = 0,
force: bool = False,
) -> None:
if __debug__ and utils.EMULATOR:
log.debug(
__name__,
@ -317,7 +321,9 @@ class Channel:
noise_payload_len = memory_manager.encode_into_buffer(
self.buffer, msg, session_id
)
return self.write_and_encrypt(self.buffer[:noise_payload_len])
task = self.write_and_encrypt(self.buffer[:noise_payload_len], force)
if task is not None:
await task
def write_error(self, err_type: int) -> Awaitable[None]:
msg_data = err_type.to_bytes(1, "big")
@ -325,7 +331,9 @@ class Channel:
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
def write_and_encrypt(self, payload: bytes) -> None:
def write_and_encrypt(
self, payload: bytes, force: bool = False
) -> Awaitable[None] | None:
payload_length = len(payload)
self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH
@ -334,11 +342,20 @@ class Channel:
self.write_task_spawn.close() # UPS TODO might break something
print("\nCLOSED\n")
self._prepare_write()
if force:
if __debug__:
log.debug(
__name__, "Writing FORCE message (without async or retransmission)."
)
return self._write_encrypted_payload_loop(
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
)
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
)
)
return None
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
self._prepare_write()

View File

@ -163,7 +163,7 @@ class PairingContext(Context):
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return self.channel_ctx.write(msg)
return await self.channel_ctx.write(msg)
async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]

View File

@ -124,7 +124,7 @@ async def handle_received_message(
)
except ThpUnallocatedSessionError as e:
error_message = Failure(code=FailureType.ThpUnallocatedSession)
ctx.write(error_message, e.session_id)
await ctx.write(error_message, e.session_id)
except ThpDecryptionError:
await ctx.write_error(ThpErrorType.DECRYPTION_FAILED)
ctx.clear()

View File

@ -138,6 +138,13 @@ class GenericSessionContext(Context):
)
message: Message = await self.incoming_message.take()
if message.type not in expected_types:
if __debug__:
log.debug(
__name__,
"EXPECTED TYPES: %s\nRECEIVED TYPE: %s",
str(expected_types),
str(message.type),
)
raise UnexpectedMessageException(message)
if expected_type is None:
@ -146,7 +153,10 @@ class GenericSessionContext(Context):
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return self.channel.write(msg, self.session_id)
return await self.channel.write(msg, self.session_id)
async def write_force(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg, self.session_id, force=True)
def get_session_state(self) -> SessionState: ...

View File

@ -23,9 +23,9 @@ class TransmissionLoop:
self.wait_task: loop.spawn | None = None
self.min_retransmisson_count_achieved: bool = False
async def start(self):
async def start(self, max_retransmission_count: int = MAX_RETRANSMISSION_COUNT):
self.min_retransmisson_count_achieved = False
for i in range(MAX_RETRANSMISSION_COUNT):
for i in range(max_retransmission_count):
if i >= MIN_RETRANSMISSION_COUNT:
self.min_retransmisson_count_achieved = True
await write_payload_to_wire_and_add_checksum(