mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-06-02 14:18:59 +00:00
chore: improve fallback
This commit is contained in:
parent
72b8b8ef1c
commit
2a70914ce1
@ -60,17 +60,18 @@ class Channel:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channel_cache: ChannelCache) -> None:
|
def __init__(self, channel_cache: ChannelCache) -> None:
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
log.debug(__name__, "channel initialization")
|
|
||||||
|
|
||||||
# Channel properties
|
# Channel properties
|
||||||
|
self.channel_id: bytes = channel_cache.channel_id
|
||||||
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
|
self._log("channel initialization")
|
||||||
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
|
self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface)
|
||||||
self.channel_cache: ChannelCache = channel_cache
|
self.channel_cache: ChannelCache = channel_cache
|
||||||
self.channel_id: bytes = channel_cache.channel_id
|
|
||||||
|
|
||||||
# Shared variables
|
# Shared variables
|
||||||
self.buffer: utils.BufferType = bytearray(self.iface.TX_PACKET_LEN)
|
self.buffer: utils.BufferType = bytearray(self.iface.TX_PACKET_LEN)
|
||||||
self.fallback_decrypt: bool = False
|
self.fallback_decrypt: bool = False
|
||||||
|
self.fallback_session_id: int | None = None
|
||||||
self.bytes_read: int = 0
|
self.bytes_read: int = 0
|
||||||
self.expected_payload_length: int = 0
|
self.expected_payload_length: int = 0
|
||||||
self.is_cont_packet_expected: bool = False
|
self.is_cont_packet_expected: bool = False
|
||||||
@ -139,7 +140,9 @@ class Channel:
|
|||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("receive packet")
|
self._log("receive packet")
|
||||||
|
|
||||||
self._handle_received_packet(packet)
|
task = self._handle_received_packet(packet)
|
||||||
|
if task is not None:
|
||||||
|
return task
|
||||||
|
|
||||||
if self.expected_payload_length == 0: # Reading failed TODO
|
if self.expected_payload_length == 0: # Reading failed TODO
|
||||||
from trezor.wire.thp import ThpErrorType
|
from trezor.wire.thp import ThpErrorType
|
||||||
@ -148,13 +151,24 @@ class Channel:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||||
except WireBufferError:
|
|
||||||
pass # TODO ??
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
try:
|
|
||||||
self._log("self.buffer: ", get_bytes_as_str(buffer))
|
self._log("self.buffer: ", get_bytes_as_str(buffer))
|
||||||
except Exception:
|
except WireBufferError:
|
||||||
pass # TODO handle nicer - happens in fallback_decrypt
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
|
self._log(
|
||||||
|
"getting read buffer failed - ", str(WireBufferError.__name__)
|
||||||
|
)
|
||||||
|
pass # TODO ??
|
||||||
|
if self.fallback_decrypt and self.expected_payload_length == self.bytes_read:
|
||||||
|
self._finish_fallback()
|
||||||
|
from trezor.messages import Failure
|
||||||
|
from trezor.enums import FailureType
|
||||||
|
|
||||||
|
return self.write(
|
||||||
|
Failure(code=FailureType.DeviceIsBusy, message="FALLBACK!"),
|
||||||
|
session_id=self.fallback_session_id or 0,
|
||||||
|
fallback=True,
|
||||||
|
)
|
||||||
|
|
||||||
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
||||||
self._finish_message()
|
self._finish_message()
|
||||||
@ -166,21 +180,30 @@ class Channel:
|
|||||||
return received_message_handler.handle_received_message(self, buffer)
|
return received_message_handler.handle_received_message(self, buffer)
|
||||||
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
|
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
|
||||||
self.is_cont_packet_expected = True
|
self.is_cont_packet_expected = True
|
||||||
|
self._log(
|
||||||
|
"CONT EXPECTED - read/expected:",
|
||||||
|
str(self.bytes_read)
|
||||||
|
+ "/"
|
||||||
|
+ str(self.expected_payload_length + INIT_HEADER_LENGTH),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ThpError(
|
raise ThpError(
|
||||||
"Read more bytes than is the expected length of the message!"
|
"Read more bytes than is the expected length of the message!"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _handle_received_packet(self, packet: utils.BufferType) -> None:
|
def _handle_received_packet(
|
||||||
|
self, packet: utils.BufferType
|
||||||
|
) -> Awaitable[None] | None:
|
||||||
ctrl_byte = packet[0]
|
ctrl_byte = packet[0]
|
||||||
if control_byte.is_continuation(ctrl_byte):
|
if control_byte.is_continuation(ctrl_byte):
|
||||||
self._handle_cont_packet(packet)
|
self._handle_cont_packet(packet)
|
||||||
return
|
return None
|
||||||
self._handle_init_packet(packet)
|
return self._handle_init_packet(packet)
|
||||||
|
|
||||||
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
def _handle_init_packet(self, packet: utils.BufferType) -> None:
|
||||||
self.fallback_decrypt = False
|
self.fallback_decrypt = False
|
||||||
|
self.fallback_session_id = None
|
||||||
self.bytes_read = 0
|
self.bytes_read = 0
|
||||||
self.expected_payload_length = 0
|
self.expected_payload_length = 0
|
||||||
|
|
||||||
@ -204,11 +227,15 @@ class Channel:
|
|||||||
try:
|
try:
|
||||||
buffer = memory_manager.get_new_read_buffer(cid, length)
|
buffer = memory_manager.get_new_read_buffer(cid, length)
|
||||||
except WireBufferError:
|
except WireBufferError:
|
||||||
|
self.fallback_decrypt = True
|
||||||
# TODO handle not encrypted/(short??), eg. ACK
|
# TODO handle not encrypted/(short??), eg. ACK
|
||||||
|
|
||||||
self.fallback_decrypt = True
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if not self._can_fallback():
|
||||||
|
raise Exception(
|
||||||
|
"Channel is in a state that does not support fallback."
|
||||||
|
)
|
||||||
|
self._log("Started fallback read")
|
||||||
self._prepare_fallback()
|
self._prepare_fallback()
|
||||||
except Exception:
|
except Exception:
|
||||||
self.fallback_decrypt = False
|
self.fallback_decrypt = False
|
||||||
@ -220,7 +247,7 @@ class Channel:
|
|||||||
log.debug(
|
log.debug(
|
||||||
__name__, "FAILED TO FALLBACK: %s", hexlify(packet).decode()
|
__name__, "FAILED TO FALLBACK: %s", hexlify(packet).decode()
|
||||||
)
|
)
|
||||||
return
|
return None
|
||||||
|
|
||||||
to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length)
|
to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length)
|
||||||
buf = memoryview(self.buffer)[:to_read_len]
|
buf = memoryview(self.buffer)[:to_read_len]
|
||||||
@ -229,17 +256,23 @@ class Channel:
|
|||||||
# CRC CHECK
|
# CRC CHECK
|
||||||
self._handle_fallback_crc(buf)
|
self._handle_fallback_crc(buf)
|
||||||
|
|
||||||
|
# Handle ACK
|
||||||
|
if control_byte.is_ack(packet[0]):
|
||||||
|
ack_bit = (packet[0] & 0x08) >> 3
|
||||||
|
return received_message_handler._handle_ack(self, ack_bit)
|
||||||
|
|
||||||
# TAG CHECK
|
# TAG CHECK
|
||||||
self._handle_fallback_decryption(buf)
|
self._handle_fallback_decryption(buf)
|
||||||
|
|
||||||
self.bytes_read += to_read_len
|
self.bytes_read += to_read_len
|
||||||
return
|
return None
|
||||||
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("handle_init_packet - payload len: ", str(payload_length))
|
self._log("handle_init_packet - payload len: ", str(payload_length))
|
||||||
self._log("handle_init_packet - buffer len: ", str(len(buffer)))
|
self._log("handle_init_packet - buffer len: ", str(len(buffer)))
|
||||||
|
|
||||||
self._buffer_packet_data(buffer, packet, 0)
|
self._buffer_packet_data(buffer, packet, 0)
|
||||||
|
return None
|
||||||
|
|
||||||
def _handle_fallback_crc(self, buf: memoryview) -> None:
|
def _handle_fallback_crc(self, buf: memoryview) -> None:
|
||||||
assert self.temp_crc is not None
|
assert self.temp_crc is not None
|
||||||
@ -303,6 +336,8 @@ class Channel:
|
|||||||
utils.memcpy(self.temp_tag, offset, noise_tag, 0)
|
utils.memcpy(self.temp_tag, offset, noise_tag, 0)
|
||||||
else:
|
else:
|
||||||
raise Exception("Buffer (+bytes_read) should not be bigger than payload")
|
raise Exception("Buffer (+bytes_read) should not be bigger than payload")
|
||||||
|
if self.fallback_session_id is None:
|
||||||
|
self.fallback_session_id = buf[0]
|
||||||
|
|
||||||
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
@ -347,6 +382,7 @@ class Channel:
|
|||||||
def _finish_fallback(self) -> None:
|
def _finish_fallback(self) -> None:
|
||||||
self.fallback_decrypt = False
|
self.fallback_decrypt = False
|
||||||
self.busy_decoder = None
|
self.busy_decoder = None
|
||||||
|
self._log("Finish fallback")
|
||||||
|
|
||||||
def _decrypt_single_packet_payload(
|
def _decrypt_single_packet_payload(
|
||||||
self, payload: utils.BufferType
|
self, payload: utils.BufferType
|
||||||
@ -419,6 +455,7 @@ class Channel:
|
|||||||
msg: protobuf.MessageType,
|
msg: protobuf.MessageType,
|
||||||
session_id: int = 0,
|
session_id: int = 0,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
|
fallback: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if __debug__ and utils.EMULATOR:
|
if __debug__ and utils.EMULATOR:
|
||||||
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
|
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
|
||||||
@ -428,6 +465,9 @@ class Channel:
|
|||||||
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
||||||
length = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + INIT_HEADER_LENGTH
|
length = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + INIT_HEADER_LENGTH
|
||||||
try:
|
try:
|
||||||
|
if fallback:
|
||||||
|
buffer = self.buffer
|
||||||
|
else:
|
||||||
buffer = memory_manager.get_new_write_buffer(cid, length)
|
buffer = memory_manager.get_new_write_buffer(cid, length)
|
||||||
noise_payload_len = memory_manager.encode_into_buffer(
|
noise_payload_len = memory_manager.encode_into_buffer(
|
||||||
buffer, msg, session_id
|
buffer, msg, session_id
|
||||||
@ -448,7 +488,9 @@ class Channel:
|
|||||||
session_id,
|
session_id,
|
||||||
)
|
)
|
||||||
self.set_channel_state(ChannelState.INVALIDATED)
|
self.set_channel_state(ChannelState.INVALIDATED)
|
||||||
task = self._write_and_encrypt(noise_payload_len, force)
|
task = self._write_and_encrypt(
|
||||||
|
noise_payload_len=noise_payload_len, force=force, fallback=fallback
|
||||||
|
)
|
||||||
if task is not None:
|
if task is not None:
|
||||||
await task
|
await task
|
||||||
|
|
||||||
@ -465,8 +507,14 @@ class Channel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _write_and_encrypt(
|
def _write_and_encrypt(
|
||||||
self, noise_payload_len: int, force: bool = False
|
self,
|
||||||
|
noise_payload_len: int,
|
||||||
|
force: bool = False,
|
||||||
|
fallback: bool = False,
|
||||||
) -> Awaitable[None] | None:
|
) -> Awaitable[None] | None:
|
||||||
|
if fallback:
|
||||||
|
buffer = self.buffer
|
||||||
|
else:
|
||||||
buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int())
|
buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int())
|
||||||
# if buffer is WireBufferError:
|
# if buffer is WireBufferError:
|
||||||
# pass # TODO handle deviceBUSY
|
# pass # TODO handle deviceBUSY
|
||||||
@ -478,6 +526,18 @@ class Channel:
|
|||||||
self.write_task_spawn.close() # UPS TODO might break something
|
self.write_task_spawn.close() # UPS TODO might break something
|
||||||
print("\nCLOSED\n")
|
print("\nCLOSED\n")
|
||||||
self._prepare_write()
|
self._prepare_write()
|
||||||
|
if fallback:
|
||||||
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
|
self._log(
|
||||||
|
"Writing FALLBACK message (written only once without async or retransmission)."
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._write_encrypted_payload_loop(
|
||||||
|
ctrl_byte=ENCRYPTED,
|
||||||
|
payload=memoryview(buffer[:payload_length]),
|
||||||
|
only_once=True,
|
||||||
|
)
|
||||||
|
|
||||||
if force:
|
if force:
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("Writing FORCE message (without async or retransmission).")
|
self._log("Writing FORCE message (without async or retransmission).")
|
||||||
@ -497,7 +557,7 @@ class Channel:
|
|||||||
ABP.set_sending_allowed(self.channel_cache, False)
|
ABP.set_sending_allowed(self.channel_cache, False)
|
||||||
|
|
||||||
async def _write_encrypted_payload_loop(
|
async def _write_encrypted_payload_loop(
|
||||||
self, ctrl_byte: int, payload: bytes
|
self, ctrl_byte: int, payload: bytes, only_once: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("write_encrypted_payload_loop")
|
self._log("write_encrypted_payload_loop")
|
||||||
@ -507,6 +567,9 @@ class Channel:
|
|||||||
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
|
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(ctrl_byte, sync_bit)
|
||||||
header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
|
header = PacketHeader(ctrl_byte, self.get_channel_id_int(), payload_len)
|
||||||
self.transmission_loop = TransmissionLoop(self, header, payload)
|
self.transmission_loop = TransmissionLoop(self, header, payload)
|
||||||
|
if only_once:
|
||||||
|
await self.transmission_loop.start(max_retransmission_count=1)
|
||||||
|
else:
|
||||||
await self.transmission_loop.start()
|
await self.transmission_loop.start()
|
||||||
|
|
||||||
ABP.set_send_seq_bit_to_opposite(self.channel_cache)
|
ABP.set_send_seq_bit_to_opposite(self.channel_cache)
|
||||||
@ -516,7 +579,7 @@ class Channel:
|
|||||||
if self._can_clear_loop():
|
if self._can_clear_loop():
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("clearing loop from channel")
|
self._log("clearing loop from channel")
|
||||||
|
pass
|
||||||
loop.clear()
|
loop.clear()
|
||||||
|
|
||||||
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
||||||
@ -549,6 +612,14 @@ class Channel:
|
|||||||
not workflow.tasks
|
not workflow.tasks
|
||||||
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
|
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
|
||||||
|
|
||||||
|
def _can_fallback(self) -> bool:
|
||||||
|
state = self.get_channel_state()
|
||||||
|
return state not in [
|
||||||
|
ChannelState.TH1,
|
||||||
|
ChannelState.TH2,
|
||||||
|
ChannelState.UNALLOCATED,
|
||||||
|
]
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
|
|
||||||
def _log(self, text_1: str, text_2: str = "") -> None:
|
def _log(self, text_1: str, text_2: str = "") -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user