1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-12 09:28:10 +00:00

wip buffer locking-

[no changelog]
This commit is contained in:
M1nd3r 2024-12-11 11:46:55 +01:00
parent 7edc8dc0f6
commit fa8cf69c21
4 changed files with 226 additions and 161 deletions

View File

@ -67,6 +67,7 @@ class ChannelState(IntEnum):
TP4 = 6 TP4 = 6
TC1 = 7 TC1 = 7
ENCRYPTED_TRANSPORT = 8 ENCRYPTED_TRANSPORT = 8
INVALIDATED = 9
class SessionState(IntEnum): class SessionState(IntEnum):

View File

@ -8,7 +8,12 @@ from storage.cache_common import (
CHANNEL_NONCE_RECEIVE, CHANNEL_NONCE_RECEIVE,
CHANNEL_NONCE_SEND, CHANNEL_NONCE_SEND,
) )
from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id from storage.cache_thp import (
SESSION_ID_LENGTH,
TAG_LENGTH,
ChannelCache,
clear_sessions_with_channel_id,
)
from trezor import log, loop, protobuf, utils, workflow from trezor import log, loop, protobuf, utils, workflow
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
@ -25,6 +30,7 @@ from .transmission_loop import TransmissionLoop
from .writer import ( from .writer import (
CONT_HEADER_LENGTH, CONT_HEADER_LENGTH,
INIT_HEADER_LENGTH, INIT_HEADER_LENGTH,
MESSAGE_TYPE_LENGTH,
PACKET_LENGTH, PACKET_LENGTH,
write_payload_to_wire_and_add_checksum, write_payload_to_wire_and_add_checksum,
) )
@ -58,6 +64,7 @@ class Channel:
# Shared variables # Shared variables
self.buffer: utils.BufferType = bytearray(PACKET_LENGTH) self.buffer: utils.BufferType = bytearray(PACKET_LENGTH)
self.fallback_decrypt: bool = False
self.bytes_read: int = 0 self.bytes_read: int = 0
self.expected_payload_length: int = 0 self.expected_payload_length: int = 0
self.is_cont_packet_expected: bool = False self.is_cont_packet_expected: bool = False
@ -97,24 +104,25 @@ class Channel:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("set_channel_state: ", state_to_str(state)) self._log("set_channel_state: ", state_to_str(state))
def set_buffer(self, buffer: utils.BufferType) -> None:
self.buffer = buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("set_buffer: ", str(type(self.buffer)))
# READ and DECRYPT # READ and DECRYPT
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("receive packet") self._log("receive packet")
self._handle_received_packet(packet) self._handle_received_packet(packet)
try:
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
except BufferError:
pass # TODO ??
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("self.buffer: ", get_bytes_as_str(self.buffer)) self._log("self.buffer: ", get_bytes_as_str(buffer))
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
self._finish_message() self._finish_message()
return received_message_handler.handle_received_message(self, self.buffer) return received_message_handler.handle_received_message(self, buffer)
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read: elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
self.is_cont_packet_expected = True self.is_cont_packet_expected = True
else: else:
@ -136,7 +144,9 @@ class Channel:
# ctrl_byte, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) # TODO use this with single packet decryption # ctrl_byte, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) # TODO use this with single packet decryption
_, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) _, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet)
self.expected_payload_length = payload_length self.expected_payload_length = payload_length
packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
# packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
# The above could be used for single packet decryption
# If the channel does not "own" the buffer lock, decrypt first packet # If the channel does not "own" the buffer lock, decrypt first packet
# TODO do it only when needed! # TODO do it only when needed!
@ -147,18 +157,22 @@ class Channel:
# if control_byte.is_encrypted_transport(ctrl_byte): # if control_byte.is_encrypted_transport(ctrl_byte):
# packet_payload = self._decrypt_single_packet_payload(packet_payload) # packet_payload = self._decrypt_single_packet_payload(packet_payload)
self.buffer = memory_manager.select_buffer( cid = self.get_channel_id_int()
self.get_channel_state(), length = payload_length + INIT_HEADER_LENGTH
self.buffer, try:
packet_payload, buffer = memory_manager.get_new_read_buffer(cid, length)
payload_length, except BufferError:
) self.fallback_decrypt = True
# TODO decrypt packet by packet, keep track of length, at the end call _finish_message to clear mess
# if buffer is BufferError:
# pass # TODO handle deviceBUSY
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("handle_init_packet - payload len: ", str(payload_length)) self._log("handle_init_packet - payload len: ", str(payload_length))
self._log("handle_init_packet - buffer len: ", str(len(self.buffer))) self._log("handle_init_packet - buffer len: ", str(len(buffer)))
return self._buffer_packet_data(self.buffer, packet, 0) self._buffer_packet_data(buffer, packet, 0)
def _handle_cont_packet(self, packet: utils.BufferType) -> None: def _handle_cont_packet(self, packet: utils.BufferType) -> None:
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
@ -166,7 +180,12 @@ class Channel:
if not self.is_cont_packet_expected: if not self.is_cont_packet_expected:
raise ThpError("Continuation packet is not expected, ignoring") raise ThpError("Continuation packet is not expected, ignoring")
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
try:
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
except BufferError:
pass # TODO handle device busy, channel kaput
return self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH)
def _buffer_packet_data( def _buffer_packet_data(
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
@ -174,6 +193,7 @@ class Channel:
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
def _finish_message(self) -> None: def _finish_message(self) -> None:
self.fallback_decrypt = False
self.bytes_read = 0 self.bytes_read = 0
self.expected_payload_length = 0 self.expected_payload_length = 0
self.is_cont_packet_expected = False self.is_cont_packet_expected = False
@ -187,15 +207,19 @@ class Channel:
def decrypt_buffer( def decrypt_buffer(
self, message_length: int, offset: int = INIT_HEADER_LENGTH self, message_length: int, offset: int = INIT_HEADER_LENGTH
) -> None: ) -> None:
noise_buffer = memoryview(self.buffer)[ buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
# if buffer is BufferError:
# pass # TODO handle deviceBUSY
noise_buffer = memoryview(buffer)[
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
] ]
tag = self.buffer[ tag = buffer[
message_length message_length
- CHECKSUM_LENGTH - CHECKSUM_LENGTH
- TAG_LENGTH : message_length - TAG_LENGTH : message_length
- CHECKSUM_LENGTH - CHECKSUM_LENGTH
] ]
if utils.DISABLE_ENCRYPTION: if utils.DISABLE_ENCRYPTION:
is_tag_valid = tag == crypto.DUMMY_TAG is_tag_valid = tag == crypto.DUMMY_TAG
else: else:
@ -234,11 +258,33 @@ class Channel:
if __debug__ and utils.EMULATOR: if __debug__ and utils.EMULATOR:
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg)) self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
self.buffer = memory_manager.get_write_buffer(self.buffer, msg) cid = self.get_channel_id_int()
noise_payload_len = memory_manager.encode_into_buffer( msg_size = protobuf.encoded_length(msg)
self.buffer, msg, session_id payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
) length = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + INIT_HEADER_LENGTH
return self._write_and_encrypt(self.buffer[:noise_payload_len]) try:
buffer = memory_manager.get_new_write_buffer(cid, length)
noise_payload_len = memory_manager.encode_into_buffer(
buffer, msg, session_id
)
except BufferError:
from trezor.messages import Failure, FailureType
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
self._log("Failed to get write buffer, killing channel.")
noise_payload_len = memory_manager.encode_into_buffer(
self.buffer,
Failure(
code=FailureType.FirmwareError,
message="Failed to obtain write buffer.",
),
session_id,
)
self.set_channel_state(ChannelState.INVALIDATED)
return self._write_and_encrypt(buffer[:noise_payload_len])
return self._write_and_encrypt(buffer[:noise_payload_len])
def write_error(self, err_type: int) -> Awaitable[None]: def write_error(self, err_type: int) -> Awaitable[None]:
msg_data = err_type.to_bytes(1, "big") msg_data = err_type.to_bytes(1, "big")
@ -254,7 +300,11 @@ class Channel:
def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]: def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]:
payload_length = len(payload) payload_length = len(payload)
self._encrypt(self.buffer, payload_length) buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int())
# if buffer is BufferError:
# pass # TODO handle deviceBUSY
self._encrypt(buffer, payload_length)
payload_length = payload_length + TAG_LENGTH payload_length = payload_length + TAG_LENGTH
if self.write_task_spawn is not None: if self.write_task_spawn is not None:
@ -263,7 +313,7 @@ class Channel:
self._prepare_write() self._prepare_write()
self.write_task_spawn = loop.spawn( self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop( self._write_encrypted_payload_loop(
ENCRYPTED, memoryview(self.buffer[:payload_length]) ENCRYPTED, memoryview(buffer[:payload_length])
) )
) )
return self.write_task_spawn return self.write_task_spawn

View File

@ -1,70 +1,150 @@
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH import utime
from trezor import log, protobuf, utils from micropython import const
from storage.cache_thp import SESSION_ID_LENGTH
from trezor import protobuf, utils
from trezor.wire.message_handler import get_msg_type from trezor.wire.message_handler import get_msg_type
from . import ChannelState, ThpError from . import ThpError
from .checksum import CHECKSUM_LENGTH from .writer import MAX_PAYLOAD_LEN, MESSAGE_TYPE_LENGTH
from .writer import (
INIT_HEADER_LENGTH,
MAX_PAYLOAD_LEN,
MESSAGE_TYPE_LENGTH,
PACKET_LENGTH,
)
_PROTOBUF_BUFFER_SIZE = 8192 _PROTOBUF_BUFFER_SIZE = 8192
READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
def select_buffer( lock_owner_cid: int | None = None
channel_state: int, lock_time: int = 0
channel_buffer: utils.BufferType,
packet_payload: utils.BufferType,
payload_length: int,
) -> utils.BufferType:
if channel_state is ChannelState.ENCRYPTED_TRANSPORT: READ_BUFFER_SLICE: memoryview | None = None
session_id = packet_payload[0] WRITE_BUFFER_SLICE: memoryview | None = None
if session_id == 0:
pass # Buffer types
# TODO use small buffer _READ: int = const(0)
else: _WRITE: int = const(1)
pass
# TODO use big buffer but only if the channel owns the buffer lock.
# Otherwise send BUSY message and return #
# Access to buffer slices
def get_new_read_buffer(channel_id: int, length: int) -> memoryview:
return _get_new_buffer(_READ, channel_id, length)
def get_new_write_buffer(channel_id: int, length: int) -> memoryview:
return _get_new_buffer(_WRITE, channel_id, length)
def get_existing_read_buffer(channel_id: int) -> memoryview:
return _get_existing_buffer(_READ, channel_id)
def get_existing_write_buffer(channel_id: int) -> memoryview:
return _get_existing_buffer(_WRITE, channel_id)
def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview:
if is_locked():
if not is_owner(channel_id):
raise BufferError
update_lock_time()
else: else:
pass update_lock(channel_id)
# TODO use small buffer
try: if buffer_type == _READ:
# TODO for now, we create a new big buffer every time. It should be changed global READ_BUFFER
buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer) buffer = READ_BUFFER
return buffer elif buffer_type == _WRITE:
except Exception as e: global WRITE_BUFFER
if __debug__ and utils.ALLOW_DEBUG_MESSAGES: buffer = WRITE_BUFFER
log.exception(__name__, e) else:
raise Exception("Failed to create a buffer for channel") # TODO handle better raise Exception("Invalid buffer_type")
if length > MAX_PAYLOAD_LEN or length > len(buffer):
raise ThpError("Message is too large") # TODO reword
if buffer_type == _READ:
global READ_BUFFER_SLICE
READ_BUFFER_SLICE = memoryview(READ_BUFFER)[:length]
return READ_BUFFER_SLICE
if buffer_type == _WRITE:
global WRITE_BUFFER_SLICE
WRITE_BUFFER_SLICE = memoryview(WRITE_BUFFER)[:length]
return WRITE_BUFFER_SLICE
raise Exception("Invalid buffer_type")
def get_write_buffer( def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview:
buffer: utils.BufferType, msg: protobuf.MessageType if not is_owner(channel_id):
) -> utils.BufferType: raise BufferError
msg_size = protobuf.encoded_length(msg) update_lock_time()
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(buffer): if buffer_type == _READ:
return _get_buffer_for_write(required_min_size, buffer) global READ_BUFFER_SLICE
return buffer if READ_BUFFER_SLICE is None:
raise BufferError
return READ_BUFFER_SLICE
if buffer_type == _WRITE:
global WRITE_BUFFER_SLICE
if WRITE_BUFFER_SLICE is None:
raise BufferError
return WRITE_BUFFER_SLICE
raise Exception("Invalid buffer_type")
#
# Buffer locking
def is_locked() -> bool:
global lock_owner_cid
global lock_time
time_diff = utime.ticks_diff(utime.ticks_ms(), lock_time)
return lock_owner_cid is not None and time_diff < 200
def is_owner(channel_id: int) -> bool:
global lock_owner_cid
return lock_owner_cid is not None and lock_owner_cid == channel_id
def update_lock(channel_id: int) -> None:
set_owner(channel_id)
update_lock_time()
def set_owner(channel_id: int) -> None:
global lock_owner_cid
lock_owner_cid = channel_id
def update_lock_time() -> None:
global lock_time
lock_time = utime.ticks_ms()
#
# Helper for encoding messages into buffer
def encode_into_buffer( def encode_into_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int: ) -> int:
"""Encode protobuf message `msg` into the `buffer`, including session id
an messages's wire type. Will fail if provided message has no wire type."""
# cannot write message without wire type # cannot write message without wire type
msg_type = msg.MESSAGE_WIRE_TYPE msg_type = msg.MESSAGE_WIRE_TYPE
if msg_type is None: if msg_type is None:
msg_type = get_msg_type(msg.MESSAGE_NAME) msg_type = get_msg_type(msg.MESSAGE_NAME)
assert msg_type is not None if msg_type is None:
raise Exception("Message has no wire type.")
msg_size = protobuf.encoded_length(msg) msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
@ -96,84 +176,3 @@ def _encode_message_into_buffer(
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0 buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
) -> None: ) -> None:
protobuf.encode(memoryview(buffer[buffer_offset:]), message) protobuf.encode(memoryview(buffer[buffer_offset:]), message)
def _get_buffer_for_read(
payload_length: int,
existing_buffer: utils.BufferType,
max_length: int = MAX_PAYLOAD_LEN,
) -> utils.BufferType:
length = payload_length + INIT_HEADER_LENGTH
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"get_buffer_for_read - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Allocating a new buffer")
if length > len(READ_BUFFER):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"Required length is %d, where raw buffer has capacity only %d",
length,
len(READ_BUFFER),
)
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(READ_BUFFER)[:length]
except MemoryError:
payload = memoryview(READ_BUFFER)[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload
# reuse a part of the supplied buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Reusing already allocated buffer")
return memoryview(existing_buffer)[:length]
def _get_buffer_for_write(
payload_length: int,
existing_buffer: utils.BufferType,
max_length: int = MAX_PAYLOAD_LEN,
) -> utils.BufferType:
length = payload_length + INIT_HEADER_LENGTH
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(
__name__,
"get_buffer_for_write - length: %d, %s %s",
length,
"existing buffer type:",
type(existing_buffer),
)
if length > max_length:
raise ThpError("Message too large")
if length > len(existing_buffer):
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Creating a new write buffer from raw write buffer")
if length > len(WRITE_BUFFER):
raise ThpError("Message is too large")
try:
payload: utils.BufferType = memoryview(WRITE_BUFFER)[:length]
except MemoryError:
payload = memoryview(WRITE_BUFFER)[:PACKET_LENGTH]
raise ThpError("Message is too large")
return payload
# reuse a part of the supplied buffer
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
log.debug(__name__, "Reusing already allocated buffer")
return memoryview(existing_buffer)[:length]

View File

@ -18,6 +18,7 @@ from storage.cache_thp import (
from trezor import log, loop, protobuf, utils from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType from trezor.enums import FailureType
from trezor.messages import Failure from trezor.messages import Failure
from trezor.wire.thp import memory_manager
from .. import message_handler from .. import message_handler
from ..errors import DataError from ..errors import DataError
@ -227,8 +228,12 @@ async def _handle_state_TH1(
ctx.handshake = Handshake() ctx.handshake = Handshake()
buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int())
# if buffer is BufferError:
# pass # TODO buffer is gone :/
host_ephemeral_pubkey = bytearray( host_ephemeral_pubkey = bytearray(
ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH] buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH]
) )
trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = ( trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = (
ctx.handshake.handle_th1_crypto( ctx.handshake.handle_th1_crypto(
@ -267,10 +272,13 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
if ctx.handshake is None: if ctx.handshake is None:
raise Exception("Handshake object is not prepared. Retry handshake.") raise Exception("Handshake object is not prepared. Retry handshake.")
host_encrypted_static_pubkey = memoryview(ctx.buffer)[ buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int())
# if buffer is BufferError:
# pass # TODO handle
host_encrypted_static_pubkey = buffer[
INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH
] ]
handshake_completion_request_noise_payload = memoryview(ctx.buffer)[ handshake_completion_request_noise_payload = buffer[
INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
] ]
@ -285,7 +293,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1) ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1)
noise_payload = _decode_message( noise_payload = _decode_message(
ctx.buffer[ buffer[
INIT_HEADER_LENGTH INIT_HEADER_LENGTH
+ KEY_LENGTH + KEY_LENGTH
+ TAG_LENGTH : message_length + TAG_LENGTH : message_length
@ -349,8 +357,12 @@ async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT") log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
ctx.decrypt_buffer(message_length) ctx.decrypt_buffer(message_length)
buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int())
# if buffer is BufferError:
# pass # TODO handle
session_id, message_type = ustruct.unpack( session_id, message_type = ustruct.unpack(
">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:] ">BH", memoryview(buffer)[INIT_HEADER_LENGTH:]
) )
if session_id not in ctx.sessions: if session_id not in ctx.sessions:
@ -372,7 +384,7 @@ async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -
s.incoming_message.put( s.incoming_message.put(
Message( Message(
message_type, message_type,
ctx.buffer[ buffer[
INIT_HEADER_LENGTH INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH + MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length + SESSION_ID_LENGTH : message_length
@ -391,14 +403,17 @@ async def _handle_pairing(ctx: Channel, message_length: int) -> None:
loop.schedule(ctx.connection_context.handle()) loop.schedule(ctx.connection_context.handle())
ctx.decrypt_buffer(message_length) ctx.decrypt_buffer(message_length)
buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int())
# if buffer is BufferError:
# pass # TODO handle
message_type = ustruct.unpack( message_type = ustruct.unpack(
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :] ">H", buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
)[0] )[0]
ctx.connection_context.incoming_message.put( ctx.connection_context.incoming_message.put(
Message( Message(
message_type, message_type,
ctx.buffer[ buffer[
INIT_HEADER_LENGTH INIT_HEADER_LENGTH
+ MESSAGE_TYPE_LENGTH + MESSAGE_TYPE_LENGTH
+ SESSION_ID_LENGTH : message_length + SESSION_ID_LENGTH : message_length