mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
refactor(core): clean channel and received_message_handler
[no changelog]
This commit is contained in:
parent
5c7f5edb80
commit
517707a1c2
@ -76,6 +76,7 @@ class Channel:
|
|||||||
self.channel_cache.clear()
|
self.channel_cache.clear()
|
||||||
|
|
||||||
# ACCESS TO CHANNEL_DATA
|
# ACCESS TO CHANNEL_DATA
|
||||||
|
|
||||||
def get_channel_id_int(self) -> int:
|
def get_channel_id_int(self) -> int:
|
||||||
return int.from_bytes(self.channel_id, "big")
|
return int.from_bytes(self.channel_id, "big")
|
||||||
|
|
||||||
@ -100,7 +101,7 @@ class Channel:
|
|||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("set_buffer: ", str(type(self.buffer)))
|
self._log("set_buffer: ", str(type(self.buffer)))
|
||||||
|
|
||||||
# CALLED BY THP_MAIN_LOOP
|
# READ and DECRYPT
|
||||||
|
|
||||||
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
@ -108,7 +109,7 @@ class Channel:
|
|||||||
self._handle_received_packet(packet)
|
self._handle_received_packet(packet)
|
||||||
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("self.buffer: ", utils.get_bytes_as_str(self.buffer))
|
self._log("self.buffer: ", get_bytes_as_str(self.buffer))
|
||||||
|
|
||||||
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
||||||
self._finish_message()
|
self._finish_message()
|
||||||
@ -166,6 +167,16 @@ class Channel:
|
|||||||
raise ThpError("Continuation packet is not expected, ignoring")
|
raise ThpError("Continuation packet is not expected, ignoring")
|
||||||
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
||||||
|
|
||||||
|
def _buffer_packet_data(
|
||||||
|
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||||
|
) -> None:
|
||||||
|
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||||
|
|
||||||
|
def _finish_message(self) -> None:
|
||||||
|
self.bytes_read = 0
|
||||||
|
self.expected_payload_length = 0
|
||||||
|
self.is_cont_packet_expected = False
|
||||||
|
|
||||||
def _decrypt_single_packet_payload(
|
def _decrypt_single_packet_payload(
|
||||||
self, payload: utils.BufferType
|
self, payload: utils.BufferType
|
||||||
) -> utils.BufferType:
|
) -> utils.BufferType:
|
||||||
@ -212,42 +223,7 @@ class Channel:
|
|||||||
if not is_tag_valid:
|
if not is_tag_valid:
|
||||||
raise ThpDecryptionError()
|
raise ThpDecryptionError()
|
||||||
|
|
||||||
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
# WRITE and ENCRYPT
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
self._log("encrypt")
|
|
||||||
|
|
||||||
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
|
||||||
|
|
||||||
noise_buffer = memoryview(buffer)[0:noise_payload_len]
|
|
||||||
|
|
||||||
if utils.DISABLE_ENCRYPTION:
|
|
||||||
tag = crypto.DUMMY_TAG
|
|
||||||
else:
|
|
||||||
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
|
|
||||||
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
|
|
||||||
|
|
||||||
assert key_send is not None
|
|
||||||
assert nonce_send is not None
|
|
||||||
|
|
||||||
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
|
|
||||||
|
|
||||||
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
self._log("New nonce_send: ", str((nonce_send + 1)))
|
|
||||||
|
|
||||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
|
||||||
|
|
||||||
def _buffer_packet_data(
|
|
||||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
|
||||||
) -> None:
|
|
||||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
|
||||||
|
|
||||||
def _finish_message(self) -> None:
|
|
||||||
self.bytes_read = 0
|
|
||||||
self.expected_payload_length = 0
|
|
||||||
self.is_cont_packet_expected = False
|
|
||||||
|
|
||||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
|
||||||
|
|
||||||
async def write(
|
async def write(
|
||||||
self,
|
self,
|
||||||
@ -262,7 +238,7 @@ class Channel:
|
|||||||
noise_payload_len = memory_manager.encode_into_buffer(
|
noise_payload_len = memory_manager.encode_into_buffer(
|
||||||
self.buffer, msg, session_id
|
self.buffer, msg, session_id
|
||||||
)
|
)
|
||||||
task = self.write_and_encrypt(self.buffer[:noise_payload_len], force)
|
task = self._write_and_encrypt(self.buffer[:noise_payload_len], force)
|
||||||
if task is not None:
|
if task is not None:
|
||||||
await task
|
await task
|
||||||
|
|
||||||
@ -272,7 +248,13 @@ class Channel:
|
|||||||
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
|
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
|
||||||
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
|
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
|
||||||
|
|
||||||
def write_and_encrypt(
|
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
||||||
|
self._prepare_write()
|
||||||
|
self.write_task_spawn = loop.spawn(
|
||||||
|
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _write_and_encrypt(
|
||||||
self, payload: bytes, force: bool = False
|
self, payload: bytes, force: bool = False
|
||||||
) -> Awaitable[None] | None:
|
) -> Awaitable[None] | None:
|
||||||
payload_length = len(payload)
|
payload_length = len(payload)
|
||||||
@ -297,12 +279,6 @@ class Channel:
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
|
||||||
self._prepare_write()
|
|
||||||
self.write_task_spawn = loop.spawn(
|
|
||||||
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_write(self) -> None:
|
def _prepare_write(self) -> None:
|
||||||
# TODO add condition that disallows to write when can_send_message is false
|
# TODO add condition that disallows to write when can_send_message is false
|
||||||
ABP.set_sending_allowed(self.channel_cache, False)
|
ABP.set_sending_allowed(self.channel_cache, False)
|
||||||
@ -330,6 +306,31 @@ class Channel:
|
|||||||
|
|
||||||
loop.clear()
|
loop.clear()
|
||||||
|
|
||||||
|
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
||||||
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
|
self._log("encrypt")
|
||||||
|
|
||||||
|
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
||||||
|
|
||||||
|
noise_buffer = memoryview(buffer)[0:noise_payload_len]
|
||||||
|
|
||||||
|
if utils.DISABLE_ENCRYPTION:
|
||||||
|
tag = crypto.DUMMY_TAG
|
||||||
|
else:
|
||||||
|
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
|
||||||
|
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
|
||||||
|
|
||||||
|
assert key_send is not None
|
||||||
|
assert nonce_send is not None
|
||||||
|
|
||||||
|
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
|
||||||
|
|
||||||
|
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
|
||||||
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
|
self._log("New nonce_send: ", str((nonce_send + 1)))
|
||||||
|
|
||||||
|
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||||
|
|
||||||
def _can_clear_loop(self) -> bool:
|
def _can_clear_loop(self) -> bool:
|
||||||
return (
|
return (
|
||||||
not workflow.tasks
|
not workflow.tasks
|
||||||
@ -341,7 +342,7 @@ class Channel:
|
|||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"(cid: %s) %s%s",
|
"(cid: %s) %s%s",
|
||||||
utils.get_bytes_as_str(self.channel_id),
|
get_bytes_as_str(self.channel_id),
|
||||||
text_1,
|
text_1,
|
||||||
text_2,
|
text_2,
|
||||||
)
|
)
|
||||||
|
@ -60,9 +60,7 @@ if TYPE_CHECKING:
|
|||||||
from .channel import Channel
|
from .channel import Channel
|
||||||
|
|
||||||
if __debug__:
|
if __debug__:
|
||||||
from ubinascii import hexlify
|
from trezor.utils import get_bytes_as_str
|
||||||
|
|
||||||
from . import state_to_str
|
|
||||||
|
|
||||||
|
|
||||||
_TREZOR_STATE_UNPAIRED = b"\x00"
|
_TREZOR_STATE_UNPAIRED = b"\x00"
|
||||||
@ -198,8 +196,6 @@ def _handle_message_to_app_or_channel(
|
|||||||
ctrl_byte: int,
|
ctrl_byte: int,
|
||||||
) -> Awaitable[None]:
|
) -> Awaitable[None]:
|
||||||
state = ctx.get_channel_state()
|
state = ctx.get_channel_state()
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
log.debug(__name__, "state: %s", state_to_str(state))
|
|
||||||
|
|
||||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||||
return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
|
return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
|
||||||
@ -244,14 +240,14 @@ async def _handle_state_TH1(
|
|||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"trezor ephemeral pubkey: %s",
|
"trezor ephemeral pubkey: %s",
|
||||||
hexlify(trezor_ephemeral_pubkey).decode(),
|
get_bytes_as_str(trezor_ephemeral_pubkey),
|
||||||
)
|
)
|
||||||
log.debug(
|
log.debug(
|
||||||
__name__,
|
__name__,
|
||||||
"encrypted trezor masked static pubkey: %s",
|
"encrypted trezor masked static pubkey: %s",
|
||||||
hexlify(encrypted_trezor_static_pubkey).decode(),
|
get_bytes_as_str(encrypted_trezor_static_pubkey),
|
||||||
)
|
)
|
||||||
log.debug(__name__, "tag: %s", hexlify(tag))
|
log.debug(__name__, "tag: %s", get_bytes_as_str(tag))
|
||||||
|
|
||||||
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
|
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user