mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-13 09:58:09 +00:00
refactor(core): clean channel and received_message_handler
[no changelog]
This commit is contained in:
parent
ffd9b16e2a
commit
c226f4242c
@ -76,6 +76,7 @@ class Channel:
|
||||
self.channel_cache.clear()
|
||||
|
||||
# ACCESS TO CHANNEL_DATA
|
||||
|
||||
def get_channel_id_int(self) -> int:
|
||||
return int.from_bytes(self.channel_id, "big")
|
||||
|
||||
@ -100,7 +101,7 @@ class Channel:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
self._log("set_buffer: ", str(type(self.buffer)))
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
# READ and DECRYPT
|
||||
|
||||
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
@ -108,7 +109,7 @@ class Channel:
|
||||
self._handle_received_packet(packet)
|
||||
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
self._log("self.buffer: ", utils.get_bytes_as_str(self.buffer))
|
||||
self._log("self.buffer: ", get_bytes_as_str(self.buffer))
|
||||
|
||||
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
||||
self._finish_message()
|
||||
@ -166,6 +167,16 @@ class Channel:
|
||||
raise ThpError("Continuation packet is not expected, ignoring")
|
||||
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
||||
|
||||
def _buffer_packet_data(
|
||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||
) -> None:
|
||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||
|
||||
def _finish_message(self) -> None:
|
||||
self.bytes_read = 0
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
def _decrypt_single_packet_payload(
|
||||
self, payload: utils.BufferType
|
||||
) -> utils.BufferType:
|
||||
@ -212,42 +223,7 @@ class Channel:
|
||||
if not is_tag_valid:
|
||||
raise ThpDecryptionError()
|
||||
|
||||
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
self._log("encrypt")
|
||||
|
||||
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
||||
|
||||
noise_buffer = memoryview(buffer)[0:noise_payload_len]
|
||||
|
||||
if utils.DISABLE_ENCRYPTION:
|
||||
tag = crypto.DUMMY_TAG
|
||||
else:
|
||||
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
|
||||
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
|
||||
|
||||
assert key_send is not None
|
||||
assert nonce_send is not None
|
||||
|
||||
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
|
||||
|
||||
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
self._log("New nonce_send: ", str((nonce_send + 1)))
|
||||
|
||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||
|
||||
def _buffer_packet_data(
|
||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||
) -> None:
|
||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||
|
||||
def _finish_message(self) -> None:
|
||||
self.bytes_read = 0
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
# WRITE and ENCRYPT
|
||||
|
||||
async def write(
|
||||
self,
|
||||
@ -262,7 +238,7 @@ class Channel:
|
||||
noise_payload_len = memory_manager.encode_into_buffer(
|
||||
self.buffer, msg, session_id
|
||||
)
|
||||
task = self.write_and_encrypt(self.buffer[:noise_payload_len], force)
|
||||
task = self._write_and_encrypt(self.buffer[:noise_payload_len], force)
|
||||
if task is not None:
|
||||
await task
|
||||
|
||||
@ -272,7 +248,13 @@ class Channel:
|
||||
header = PacketHeader.get_error_header(self.get_channel_id_int(), length)
|
||||
return write_payload_to_wire_and_add_checksum(self.iface, header, msg_data)
|
||||
|
||||
def write_and_encrypt(
|
||||
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
||||
self._prepare_write()
|
||||
self.write_task_spawn = loop.spawn(
|
||||
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
||||
)
|
||||
|
||||
def _write_and_encrypt(
|
||||
self, payload: bytes, force: bool = False
|
||||
) -> Awaitable[None] | None:
|
||||
payload_length = len(payload)
|
||||
@ -297,12 +279,6 @@ class Channel:
|
||||
)
|
||||
return None
|
||||
|
||||
def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None:
|
||||
self._prepare_write()
|
||||
self.write_task_spawn = loop.spawn(
|
||||
self._write_encrypted_payload_loop(ctrl_byte, payload)
|
||||
)
|
||||
|
||||
def _prepare_write(self) -> None:
|
||||
# TODO add condition that disallows to write when can_send_message is false
|
||||
ABP.set_sending_allowed(self.channel_cache, False)
|
||||
@ -330,6 +306,31 @@ class Channel:
|
||||
|
||||
loop.clear()
|
||||
|
||||
def _encrypt(self, buffer: utils.BufferType, noise_payload_len: int) -> None:
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
self._log("encrypt")
|
||||
|
||||
assert len(buffer) >= noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
||||
|
||||
noise_buffer = memoryview(buffer)[0:noise_payload_len]
|
||||
|
||||
if utils.DISABLE_ENCRYPTION:
|
||||
tag = crypto.DUMMY_TAG
|
||||
else:
|
||||
key_send = self.channel_cache.get(CHANNEL_KEY_SEND)
|
||||
nonce_send = self.channel_cache.get_int(CHANNEL_NONCE_SEND)
|
||||
|
||||
assert key_send is not None
|
||||
assert nonce_send is not None
|
||||
|
||||
tag = crypto.enc(noise_buffer, key_send, nonce_send, b"")
|
||||
|
||||
self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1)
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
self._log("New nonce_send: ", str((nonce_send + 1)))
|
||||
|
||||
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
||||
|
||||
def _can_clear_loop(self) -> bool:
|
||||
return (
|
||||
not workflow.tasks
|
||||
@ -341,7 +342,7 @@ class Channel:
|
||||
log.debug(
|
||||
__name__,
|
||||
"(cid: %s) %s%s",
|
||||
utils.get_bytes_as_str(self.channel_id),
|
||||
get_bytes_as_str(self.channel_id),
|
||||
text_1,
|
||||
text_2,
|
||||
)
|
||||
|
@ -60,9 +60,7 @@ if TYPE_CHECKING:
|
||||
from .channel import Channel
|
||||
|
||||
if __debug__:
|
||||
from ubinascii import hexlify
|
||||
|
||||
from . import state_to_str
|
||||
from trezor.utils import get_bytes_as_str
|
||||
|
||||
|
||||
_TREZOR_STATE_UNPAIRED = b"\x00"
|
||||
@ -198,8 +196,6 @@ def _handle_message_to_app_or_channel(
|
||||
ctrl_byte: int,
|
||||
) -> Awaitable[None]:
|
||||
state = ctx.get_channel_state()
|
||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||
log.debug(__name__, "state: %s", state_to_str(state))
|
||||
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
return _handle_state_ENCRYPTED_TRANSPORT(ctx, message_length)
|
||||
@ -244,14 +240,14 @@ async def _handle_state_TH1(
|
||||
log.debug(
|
||||
__name__,
|
||||
"trezor ephemeral pubkey: %s",
|
||||
hexlify(trezor_ephemeral_pubkey).decode(),
|
||||
get_bytes_as_str(trezor_ephemeral_pubkey),
|
||||
)
|
||||
log.debug(
|
||||
__name__,
|
||||
"encrypted trezor masked static pubkey: %s",
|
||||
hexlify(encrypted_trezor_static_pubkey).decode(),
|
||||
get_bytes_as_str(encrypted_trezor_static_pubkey),
|
||||
)
|
||||
log.debug(__name__, "tag: %s", hexlify(tag))
|
||||
log.debug(__name__, "tag: %s", get_bytes_as_str(tag))
|
||||
|
||||
payload = trezor_ephemeral_pubkey + encrypted_trezor_static_pubkey + tag
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user