mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-12 09:28:10 +00:00
wip buffer locking-
[no changelog]
This commit is contained in:
parent
7edc8dc0f6
commit
fa8cf69c21
@ -67,6 +67,7 @@ class ChannelState(IntEnum):
|
|||||||
TP4 = 6
|
TP4 = 6
|
||||||
TC1 = 7
|
TC1 = 7
|
||||||
ENCRYPTED_TRANSPORT = 8
|
ENCRYPTED_TRANSPORT = 8
|
||||||
|
INVALIDATED = 9
|
||||||
|
|
||||||
|
|
||||||
class SessionState(IntEnum):
|
class SessionState(IntEnum):
|
||||||
|
@ -8,7 +8,12 @@ from storage.cache_common import (
|
|||||||
CHANNEL_NONCE_RECEIVE,
|
CHANNEL_NONCE_RECEIVE,
|
||||||
CHANNEL_NONCE_SEND,
|
CHANNEL_NONCE_SEND,
|
||||||
)
|
)
|
||||||
from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id
|
from storage.cache_thp import (
|
||||||
|
SESSION_ID_LENGTH,
|
||||||
|
TAG_LENGTH,
|
||||||
|
ChannelCache,
|
||||||
|
clear_sessions_with_channel_id,
|
||||||
|
)
|
||||||
from trezor import log, loop, protobuf, utils, workflow
|
from trezor import log, loop, protobuf, utils, workflow
|
||||||
|
|
||||||
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
|
from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError
|
||||||
@ -25,6 +30,7 @@ from .transmission_loop import TransmissionLoop
|
|||||||
from .writer import (
|
from .writer import (
|
||||||
CONT_HEADER_LENGTH,
|
CONT_HEADER_LENGTH,
|
||||||
INIT_HEADER_LENGTH,
|
INIT_HEADER_LENGTH,
|
||||||
|
MESSAGE_TYPE_LENGTH,
|
||||||
PACKET_LENGTH,
|
PACKET_LENGTH,
|
||||||
write_payload_to_wire_and_add_checksum,
|
write_payload_to_wire_and_add_checksum,
|
||||||
)
|
)
|
||||||
@ -58,6 +64,7 @@ class Channel:
|
|||||||
|
|
||||||
# Shared variables
|
# Shared variables
|
||||||
self.buffer: utils.BufferType = bytearray(PACKET_LENGTH)
|
self.buffer: utils.BufferType = bytearray(PACKET_LENGTH)
|
||||||
|
self.fallback_decrypt: bool = False
|
||||||
self.bytes_read: int = 0
|
self.bytes_read: int = 0
|
||||||
self.expected_payload_length: int = 0
|
self.expected_payload_length: int = 0
|
||||||
self.is_cont_packet_expected: bool = False
|
self.is_cont_packet_expected: bool = False
|
||||||
@ -97,24 +104,25 @@ class Channel:
|
|||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("set_channel_state: ", state_to_str(state))
|
self._log("set_channel_state: ", state_to_str(state))
|
||||||
|
|
||||||
def set_buffer(self, buffer: utils.BufferType) -> None:
|
|
||||||
self.buffer = buffer
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
self._log("set_buffer: ", str(type(self.buffer)))
|
|
||||||
|
|
||||||
# READ and DECRYPT
|
# READ and DECRYPT
|
||||||
|
|
||||||
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None:
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("receive packet")
|
self._log("receive packet")
|
||||||
|
|
||||||
self._handle_received_packet(packet)
|
self._handle_received_packet(packet)
|
||||||
|
|
||||||
|
try:
|
||||||
|
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||||
|
except BufferError:
|
||||||
|
pass # TODO ??
|
||||||
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("self.buffer: ", get_bytes_as_str(self.buffer))
|
self._log("self.buffer: ", get_bytes_as_str(buffer))
|
||||||
|
|
||||||
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read:
|
||||||
self._finish_message()
|
self._finish_message()
|
||||||
return received_message_handler.handle_received_message(self, self.buffer)
|
return received_message_handler.handle_received_message(self, buffer)
|
||||||
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
|
elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read:
|
||||||
self.is_cont_packet_expected = True
|
self.is_cont_packet_expected = True
|
||||||
else:
|
else:
|
||||||
@ -136,7 +144,9 @@ class Channel:
|
|||||||
# ctrl_byte, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) # TODO use this with single packet decryption
|
# ctrl_byte, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) # TODO use this with single packet decryption
|
||||||
_, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet)
|
_, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet)
|
||||||
self.expected_payload_length = payload_length
|
self.expected_payload_length = payload_length
|
||||||
packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
|
|
||||||
|
# packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:]
|
||||||
|
# The above could be used for single packet decryption
|
||||||
|
|
||||||
# If the channel does not "own" the buffer lock, decrypt first packet
|
# If the channel does not "own" the buffer lock, decrypt first packet
|
||||||
# TODO do it only when needed!
|
# TODO do it only when needed!
|
||||||
@ -147,18 +157,22 @@ class Channel:
|
|||||||
# if control_byte.is_encrypted_transport(ctrl_byte):
|
# if control_byte.is_encrypted_transport(ctrl_byte):
|
||||||
# packet_payload = self._decrypt_single_packet_payload(packet_payload)
|
# packet_payload = self._decrypt_single_packet_payload(packet_payload)
|
||||||
|
|
||||||
self.buffer = memory_manager.select_buffer(
|
cid = self.get_channel_id_int()
|
||||||
self.get_channel_state(),
|
length = payload_length + INIT_HEADER_LENGTH
|
||||||
self.buffer,
|
try:
|
||||||
packet_payload,
|
buffer = memory_manager.get_new_read_buffer(cid, length)
|
||||||
payload_length,
|
except BufferError:
|
||||||
)
|
self.fallback_decrypt = True
|
||||||
|
# TODO decrypt packet by packet, keep track of length, at the end call _finish_message to clear mess
|
||||||
|
|
||||||
|
# if buffer is BufferError:
|
||||||
|
# pass # TODO handle deviceBUSY
|
||||||
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
self._log("handle_init_packet - payload len: ", str(payload_length))
|
self._log("handle_init_packet - payload len: ", str(payload_length))
|
||||||
self._log("handle_init_packet - buffer len: ", str(len(self.buffer)))
|
self._log("handle_init_packet - buffer len: ", str(len(buffer)))
|
||||||
|
|
||||||
return self._buffer_packet_data(self.buffer, packet, 0)
|
self._buffer_packet_data(buffer, packet, 0)
|
||||||
|
|
||||||
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
def _handle_cont_packet(self, packet: utils.BufferType) -> None:
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
@ -166,7 +180,12 @@ class Channel:
|
|||||||
|
|
||||||
if not self.is_cont_packet_expected:
|
if not self.is_cont_packet_expected:
|
||||||
raise ThpError("Continuation packet is not expected, ignoring")
|
raise ThpError("Continuation packet is not expected, ignoring")
|
||||||
return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH)
|
|
||||||
|
try:
|
||||||
|
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||||
|
except BufferError:
|
||||||
|
pass # TODO handle device busy, channel kaput
|
||||||
|
return self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH)
|
||||||
|
|
||||||
def _buffer_packet_data(
|
def _buffer_packet_data(
|
||||||
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
||||||
@ -174,6 +193,7 @@ class Channel:
|
|||||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||||
|
|
||||||
def _finish_message(self) -> None:
|
def _finish_message(self) -> None:
|
||||||
|
self.fallback_decrypt = False
|
||||||
self.bytes_read = 0
|
self.bytes_read = 0
|
||||||
self.expected_payload_length = 0
|
self.expected_payload_length = 0
|
||||||
self.is_cont_packet_expected = False
|
self.is_cont_packet_expected = False
|
||||||
@ -187,15 +207,19 @@ class Channel:
|
|||||||
def decrypt_buffer(
|
def decrypt_buffer(
|
||||||
self, message_length: int, offset: int = INIT_HEADER_LENGTH
|
self, message_length: int, offset: int = INIT_HEADER_LENGTH
|
||||||
) -> None:
|
) -> None:
|
||||||
noise_buffer = memoryview(self.buffer)[
|
buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int())
|
||||||
|
# if buffer is BufferError:
|
||||||
|
# pass # TODO handle deviceBUSY
|
||||||
|
noise_buffer = memoryview(buffer)[
|
||||||
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
|
offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH
|
||||||
]
|
]
|
||||||
tag = self.buffer[
|
tag = buffer[
|
||||||
message_length
|
message_length
|
||||||
- CHECKSUM_LENGTH
|
- CHECKSUM_LENGTH
|
||||||
- TAG_LENGTH : message_length
|
- TAG_LENGTH : message_length
|
||||||
- CHECKSUM_LENGTH
|
- CHECKSUM_LENGTH
|
||||||
]
|
]
|
||||||
|
|
||||||
if utils.DISABLE_ENCRYPTION:
|
if utils.DISABLE_ENCRYPTION:
|
||||||
is_tag_valid = tag == crypto.DUMMY_TAG
|
is_tag_valid = tag == crypto.DUMMY_TAG
|
||||||
else:
|
else:
|
||||||
@ -234,11 +258,33 @@ class Channel:
|
|||||||
if __debug__ and utils.EMULATOR:
|
if __debug__ and utils.EMULATOR:
|
||||||
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
|
self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg))
|
||||||
|
|
||||||
self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
|
cid = self.get_channel_id_int()
|
||||||
|
msg_size = protobuf.encoded_length(msg)
|
||||||
|
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
||||||
|
length = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + INIT_HEADER_LENGTH
|
||||||
|
try:
|
||||||
|
buffer = memory_manager.get_new_write_buffer(cid, length)
|
||||||
noise_payload_len = memory_manager.encode_into_buffer(
|
noise_payload_len = memory_manager.encode_into_buffer(
|
||||||
self.buffer, msg, session_id
|
buffer, msg, session_id
|
||||||
)
|
)
|
||||||
return self._write_and_encrypt(self.buffer[:noise_payload_len])
|
except BufferError:
|
||||||
|
from trezor.messages import Failure, FailureType
|
||||||
|
|
||||||
|
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
||||||
|
self._log("Failed to get write buffer, killing channel.")
|
||||||
|
|
||||||
|
noise_payload_len = memory_manager.encode_into_buffer(
|
||||||
|
self.buffer,
|
||||||
|
Failure(
|
||||||
|
code=FailureType.FirmwareError,
|
||||||
|
message="Failed to obtain write buffer.",
|
||||||
|
),
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
|
self.set_channel_state(ChannelState.INVALIDATED)
|
||||||
|
return self._write_and_encrypt(buffer[:noise_payload_len])
|
||||||
|
|
||||||
|
return self._write_and_encrypt(buffer[:noise_payload_len])
|
||||||
|
|
||||||
def write_error(self, err_type: int) -> Awaitable[None]:
|
def write_error(self, err_type: int) -> Awaitable[None]:
|
||||||
msg_data = err_type.to_bytes(1, "big")
|
msg_data = err_type.to_bytes(1, "big")
|
||||||
@ -254,7 +300,11 @@ class Channel:
|
|||||||
|
|
||||||
def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]:
|
def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]:
|
||||||
payload_length = len(payload)
|
payload_length = len(payload)
|
||||||
self._encrypt(self.buffer, payload_length)
|
buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int())
|
||||||
|
# if buffer is BufferError:
|
||||||
|
# pass # TODO handle deviceBUSY
|
||||||
|
|
||||||
|
self._encrypt(buffer, payload_length)
|
||||||
payload_length = payload_length + TAG_LENGTH
|
payload_length = payload_length + TAG_LENGTH
|
||||||
|
|
||||||
if self.write_task_spawn is not None:
|
if self.write_task_spawn is not None:
|
||||||
@ -263,7 +313,7 @@ class Channel:
|
|||||||
self._prepare_write()
|
self._prepare_write()
|
||||||
self.write_task_spawn = loop.spawn(
|
self.write_task_spawn = loop.spawn(
|
||||||
self._write_encrypted_payload_loop(
|
self._write_encrypted_payload_loop(
|
||||||
ENCRYPTED, memoryview(self.buffer[:payload_length])
|
ENCRYPTED, memoryview(buffer[:payload_length])
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return self.write_task_spawn
|
return self.write_task_spawn
|
||||||
|
@ -1,70 +1,150 @@
|
|||||||
from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH
|
import utime
|
||||||
from trezor import log, protobuf, utils
|
from micropython import const
|
||||||
|
|
||||||
|
from storage.cache_thp import SESSION_ID_LENGTH
|
||||||
|
from trezor import protobuf, utils
|
||||||
from trezor.wire.message_handler import get_msg_type
|
from trezor.wire.message_handler import get_msg_type
|
||||||
|
|
||||||
from . import ChannelState, ThpError
|
from . import ThpError
|
||||||
from .checksum import CHECKSUM_LENGTH
|
from .writer import MAX_PAYLOAD_LEN, MESSAGE_TYPE_LENGTH
|
||||||
from .writer import (
|
|
||||||
INIT_HEADER_LENGTH,
|
|
||||||
MAX_PAYLOAD_LEN,
|
|
||||||
MESSAGE_TYPE_LENGTH,
|
|
||||||
PACKET_LENGTH,
|
|
||||||
)
|
|
||||||
|
|
||||||
_PROTOBUF_BUFFER_SIZE = 8192
|
_PROTOBUF_BUFFER_SIZE = 8192
|
||||||
READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||||
WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||||
|
|
||||||
|
|
||||||
def select_buffer(
|
lock_owner_cid: int | None = None
|
||||||
channel_state: int,
|
lock_time: int = 0
|
||||||
channel_buffer: utils.BufferType,
|
|
||||||
packet_payload: utils.BufferType,
|
|
||||||
payload_length: int,
|
|
||||||
) -> utils.BufferType:
|
|
||||||
|
|
||||||
if channel_state is ChannelState.ENCRYPTED_TRANSPORT:
|
READ_BUFFER_SLICE: memoryview | None = None
|
||||||
session_id = packet_payload[0]
|
WRITE_BUFFER_SLICE: memoryview | None = None
|
||||||
if session_id == 0:
|
|
||||||
pass
|
# Buffer types
|
||||||
# TODO use small buffer
|
_READ: int = const(0)
|
||||||
|
_WRITE: int = const(1)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Access to buffer slices
|
||||||
|
|
||||||
|
|
||||||
|
def get_new_read_buffer(channel_id: int, length: int) -> memoryview:
|
||||||
|
return _get_new_buffer(_READ, channel_id, length)
|
||||||
|
|
||||||
|
|
||||||
|
def get_new_write_buffer(channel_id: int, length: int) -> memoryview:
|
||||||
|
return _get_new_buffer(_WRITE, channel_id, length)
|
||||||
|
|
||||||
|
|
||||||
|
def get_existing_read_buffer(channel_id: int) -> memoryview:
|
||||||
|
return _get_existing_buffer(_READ, channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_existing_write_buffer(channel_id: int) -> memoryview:
|
||||||
|
return _get_existing_buffer(_WRITE, channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview:
|
||||||
|
if is_locked():
|
||||||
|
if not is_owner(channel_id):
|
||||||
|
raise BufferError
|
||||||
|
update_lock_time()
|
||||||
else:
|
else:
|
||||||
pass
|
update_lock(channel_id)
|
||||||
# TODO use big buffer but only if the channel owns the buffer lock.
|
|
||||||
# Otherwise send BUSY message and return
|
if buffer_type == _READ:
|
||||||
|
global READ_BUFFER
|
||||||
|
buffer = READ_BUFFER
|
||||||
|
elif buffer_type == _WRITE:
|
||||||
|
global WRITE_BUFFER
|
||||||
|
buffer = WRITE_BUFFER
|
||||||
else:
|
else:
|
||||||
pass
|
raise Exception("Invalid buffer_type")
|
||||||
# TODO use small buffer
|
|
||||||
try:
|
if length > MAX_PAYLOAD_LEN or length > len(buffer):
|
||||||
# TODO for now, we create a new big buffer every time. It should be changed
|
raise ThpError("Message is too large") # TODO reword
|
||||||
buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer)
|
|
||||||
return buffer
|
if buffer_type == _READ:
|
||||||
except Exception as e:
|
global READ_BUFFER_SLICE
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
READ_BUFFER_SLICE = memoryview(READ_BUFFER)[:length]
|
||||||
log.exception(__name__, e)
|
return READ_BUFFER_SLICE
|
||||||
raise Exception("Failed to create a buffer for channel") # TODO handle better
|
|
||||||
|
if buffer_type == _WRITE:
|
||||||
|
global WRITE_BUFFER_SLICE
|
||||||
|
WRITE_BUFFER_SLICE = memoryview(WRITE_BUFFER)[:length]
|
||||||
|
return WRITE_BUFFER_SLICE
|
||||||
|
|
||||||
|
raise Exception("Invalid buffer_type")
|
||||||
|
|
||||||
|
|
||||||
def get_write_buffer(
|
def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview:
|
||||||
buffer: utils.BufferType, msg: protobuf.MessageType
|
if not is_owner(channel_id):
|
||||||
) -> utils.BufferType:
|
raise BufferError
|
||||||
msg_size = protobuf.encoded_length(msg)
|
update_lock_time()
|
||||||
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
|
||||||
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
|
|
||||||
|
|
||||||
if required_min_size > len(buffer):
|
if buffer_type == _READ:
|
||||||
return _get_buffer_for_write(required_min_size, buffer)
|
global READ_BUFFER_SLICE
|
||||||
return buffer
|
if READ_BUFFER_SLICE is None:
|
||||||
|
raise BufferError
|
||||||
|
return READ_BUFFER_SLICE
|
||||||
|
|
||||||
|
if buffer_type == _WRITE:
|
||||||
|
global WRITE_BUFFER_SLICE
|
||||||
|
if WRITE_BUFFER_SLICE is None:
|
||||||
|
raise BufferError
|
||||||
|
return WRITE_BUFFER_SLICE
|
||||||
|
|
||||||
|
raise Exception("Invalid buffer_type")
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Buffer locking
|
||||||
|
|
||||||
|
|
||||||
|
def is_locked() -> bool:
|
||||||
|
global lock_owner_cid
|
||||||
|
global lock_time
|
||||||
|
|
||||||
|
time_diff = utime.ticks_diff(utime.ticks_ms(), lock_time)
|
||||||
|
return lock_owner_cid is not None and time_diff < 200
|
||||||
|
|
||||||
|
|
||||||
|
def is_owner(channel_id: int) -> bool:
|
||||||
|
global lock_owner_cid
|
||||||
|
return lock_owner_cid is not None and lock_owner_cid == channel_id
|
||||||
|
|
||||||
|
|
||||||
|
def update_lock(channel_id: int) -> None:
|
||||||
|
set_owner(channel_id)
|
||||||
|
update_lock_time()
|
||||||
|
|
||||||
|
|
||||||
|
def set_owner(channel_id: int) -> None:
|
||||||
|
global lock_owner_cid
|
||||||
|
lock_owner_cid = channel_id
|
||||||
|
|
||||||
|
|
||||||
|
def update_lock_time() -> None:
|
||||||
|
global lock_time
|
||||||
|
lock_time = utime.ticks_ms()
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Helper for encoding messages into buffer
|
||||||
|
|
||||||
|
|
||||||
def encode_into_buffer(
|
def encode_into_buffer(
|
||||||
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
|
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
|
||||||
) -> int:
|
) -> int:
|
||||||
|
"""Encode protobuf message `msg` into the `buffer`, including session id
|
||||||
|
an messages's wire type. Will fail if provided message has no wire type."""
|
||||||
|
|
||||||
# cannot write message without wire type
|
# cannot write message without wire type
|
||||||
msg_type = msg.MESSAGE_WIRE_TYPE
|
msg_type = msg.MESSAGE_WIRE_TYPE
|
||||||
if msg_type is None:
|
if msg_type is None:
|
||||||
msg_type = get_msg_type(msg.MESSAGE_NAME)
|
msg_type = get_msg_type(msg.MESSAGE_NAME)
|
||||||
assert msg_type is not None
|
if msg_type is None:
|
||||||
|
raise Exception("Message has no wire type.")
|
||||||
|
|
||||||
msg_size = protobuf.encoded_length(msg)
|
msg_size = protobuf.encoded_length(msg)
|
||||||
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
|
||||||
@ -96,84 +176,3 @@ def _encode_message_into_buffer(
|
|||||||
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
|
buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0
|
||||||
) -> None:
|
) -> None:
|
||||||
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
|
protobuf.encode(memoryview(buffer[buffer_offset:]), message)
|
||||||
|
|
||||||
|
|
||||||
def _get_buffer_for_read(
|
|
||||||
payload_length: int,
|
|
||||||
existing_buffer: utils.BufferType,
|
|
||||||
max_length: int = MAX_PAYLOAD_LEN,
|
|
||||||
) -> utils.BufferType:
|
|
||||||
length = payload_length + INIT_HEADER_LENGTH
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"get_buffer_for_read - length: %d, %s %s",
|
|
||||||
length,
|
|
||||||
"existing buffer type:",
|
|
||||||
type(existing_buffer),
|
|
||||||
)
|
|
||||||
if length > max_length:
|
|
||||||
raise ThpError("Message too large")
|
|
||||||
|
|
||||||
if length > len(existing_buffer):
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
log.debug(__name__, "Allocating a new buffer")
|
|
||||||
|
|
||||||
if length > len(READ_BUFFER):
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"Required length is %d, where raw buffer has capacity only %d",
|
|
||||||
length,
|
|
||||||
len(READ_BUFFER),
|
|
||||||
)
|
|
||||||
raise ThpError("Message is too large")
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload: utils.BufferType = memoryview(READ_BUFFER)[:length]
|
|
||||||
except MemoryError:
|
|
||||||
payload = memoryview(READ_BUFFER)[:PACKET_LENGTH]
|
|
||||||
raise ThpError("Message is too large")
|
|
||||||
return payload
|
|
||||||
|
|
||||||
# reuse a part of the supplied buffer
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
log.debug(__name__, "Reusing already allocated buffer")
|
|
||||||
return memoryview(existing_buffer)[:length]
|
|
||||||
|
|
||||||
|
|
||||||
def _get_buffer_for_write(
|
|
||||||
payload_length: int,
|
|
||||||
existing_buffer: utils.BufferType,
|
|
||||||
max_length: int = MAX_PAYLOAD_LEN,
|
|
||||||
) -> utils.BufferType:
|
|
||||||
length = payload_length + INIT_HEADER_LENGTH
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
log.debug(
|
|
||||||
__name__,
|
|
||||||
"get_buffer_for_write - length: %d, %s %s",
|
|
||||||
length,
|
|
||||||
"existing buffer type:",
|
|
||||||
type(existing_buffer),
|
|
||||||
)
|
|
||||||
if length > max_length:
|
|
||||||
raise ThpError("Message too large")
|
|
||||||
|
|
||||||
if length > len(existing_buffer):
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
log.debug(__name__, "Creating a new write buffer from raw write buffer")
|
|
||||||
|
|
||||||
if length > len(WRITE_BUFFER):
|
|
||||||
raise ThpError("Message is too large")
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload: utils.BufferType = memoryview(WRITE_BUFFER)[:length]
|
|
||||||
except MemoryError:
|
|
||||||
payload = memoryview(WRITE_BUFFER)[:PACKET_LENGTH]
|
|
||||||
raise ThpError("Message is too large")
|
|
||||||
return payload
|
|
||||||
|
|
||||||
# reuse a part of the supplied buffer
|
|
||||||
if __debug__ and utils.ALLOW_DEBUG_MESSAGES:
|
|
||||||
log.debug(__name__, "Reusing already allocated buffer")
|
|
||||||
return memoryview(existing_buffer)[:length]
|
|
||||||
|
@ -18,6 +18,7 @@ from storage.cache_thp import (
|
|||||||
from trezor import log, loop, protobuf, utils
|
from trezor import log, loop, protobuf, utils
|
||||||
from trezor.enums import FailureType
|
from trezor.enums import FailureType
|
||||||
from trezor.messages import Failure
|
from trezor.messages import Failure
|
||||||
|
from trezor.wire.thp import memory_manager
|
||||||
|
|
||||||
from .. import message_handler
|
from .. import message_handler
|
||||||
from ..errors import DataError
|
from ..errors import DataError
|
||||||
@ -227,8 +228,12 @@ async def _handle_state_TH1(
|
|||||||
|
|
||||||
ctx.handshake = Handshake()
|
ctx.handshake = Handshake()
|
||||||
|
|
||||||
|
buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int())
|
||||||
|
# if buffer is BufferError:
|
||||||
|
# pass # TODO buffer is gone :/
|
||||||
|
|
||||||
host_ephemeral_pubkey = bytearray(
|
host_ephemeral_pubkey = bytearray(
|
||||||
ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH]
|
buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH]
|
||||||
)
|
)
|
||||||
trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = (
|
trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = (
|
||||||
ctx.handshake.handle_th1_crypto(
|
ctx.handshake.handle_th1_crypto(
|
||||||
@ -267,10 +272,13 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
|||||||
if ctx.handshake is None:
|
if ctx.handshake is None:
|
||||||
raise Exception("Handshake object is not prepared. Retry handshake.")
|
raise Exception("Handshake object is not prepared. Retry handshake.")
|
||||||
|
|
||||||
host_encrypted_static_pubkey = memoryview(ctx.buffer)[
|
buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int())
|
||||||
|
# if buffer is BufferError:
|
||||||
|
# pass # TODO handle
|
||||||
|
host_encrypted_static_pubkey = buffer[
|
||||||
INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH
|
INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH
|
||||||
]
|
]
|
||||||
handshake_completion_request_noise_payload = memoryview(ctx.buffer)[
|
handshake_completion_request_noise_payload = buffer[
|
||||||
INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
|
INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -285,7 +293,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) -
|
|||||||
ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1)
|
ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1)
|
||||||
|
|
||||||
noise_payload = _decode_message(
|
noise_payload = _decode_message(
|
||||||
ctx.buffer[
|
buffer[
|
||||||
INIT_HEADER_LENGTH
|
INIT_HEADER_LENGTH
|
||||||
+ KEY_LENGTH
|
+ KEY_LENGTH
|
||||||
+ TAG_LENGTH : message_length
|
+ TAG_LENGTH : message_length
|
||||||
@ -349,8 +357,12 @@ async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -
|
|||||||
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
|
log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT")
|
||||||
|
|
||||||
ctx.decrypt_buffer(message_length)
|
ctx.decrypt_buffer(message_length)
|
||||||
|
|
||||||
|
buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int())
|
||||||
|
# if buffer is BufferError:
|
||||||
|
# pass # TODO handle
|
||||||
session_id, message_type = ustruct.unpack(
|
session_id, message_type = ustruct.unpack(
|
||||||
">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:]
|
">BH", memoryview(buffer)[INIT_HEADER_LENGTH:]
|
||||||
)
|
)
|
||||||
if session_id not in ctx.sessions:
|
if session_id not in ctx.sessions:
|
||||||
|
|
||||||
@ -372,7 +384,7 @@ async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) -
|
|||||||
s.incoming_message.put(
|
s.incoming_message.put(
|
||||||
Message(
|
Message(
|
||||||
message_type,
|
message_type,
|
||||||
ctx.buffer[
|
buffer[
|
||||||
INIT_HEADER_LENGTH
|
INIT_HEADER_LENGTH
|
||||||
+ MESSAGE_TYPE_LENGTH
|
+ MESSAGE_TYPE_LENGTH
|
||||||
+ SESSION_ID_LENGTH : message_length
|
+ SESSION_ID_LENGTH : message_length
|
||||||
@ -391,14 +403,17 @@ async def _handle_pairing(ctx: Channel, message_length: int) -> None:
|
|||||||
loop.schedule(ctx.connection_context.handle())
|
loop.schedule(ctx.connection_context.handle())
|
||||||
|
|
||||||
ctx.decrypt_buffer(message_length)
|
ctx.decrypt_buffer(message_length)
|
||||||
|
buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int())
|
||||||
|
# if buffer is BufferError:
|
||||||
|
# pass # TODO handle
|
||||||
message_type = ustruct.unpack(
|
message_type = ustruct.unpack(
|
||||||
">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
|
">H", buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :]
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
ctx.connection_context.incoming_message.put(
|
ctx.connection_context.incoming_message.put(
|
||||||
Message(
|
Message(
|
||||||
message_type,
|
message_type,
|
||||||
ctx.buffer[
|
buffer[
|
||||||
INIT_HEADER_LENGTH
|
INIT_HEADER_LENGTH
|
||||||
+ MESSAGE_TYPE_LENGTH
|
+ MESSAGE_TYPE_LENGTH
|
||||||
+ SESSION_ID_LENGTH : message_length
|
+ SESSION_ID_LENGTH : message_length
|
||||||
|
Loading…
Reference in New Issue
Block a user