|
|
|
@ -12,7 +12,7 @@ from trezor.wire import message_handler
|
|
|
|
|
from trezor.wire.thp import thp_messages
|
|
|
|
|
|
|
|
|
|
from ..protocol_common import Context, MessageWithType
|
|
|
|
|
from . import ChannelState, SessionState, checksum
|
|
|
|
|
from . import ChannelState, SessionState, checksum, crypto
|
|
|
|
|
from . import thp_session as THP
|
|
|
|
|
from .checksum import CHECKSUM_LENGTH
|
|
|
|
|
from .thp_messages import (
|
|
|
|
@ -90,7 +90,8 @@ class Channel(Context):
|
|
|
|
|
await self._handle_cont_packet(packet)
|
|
|
|
|
else:
|
|
|
|
|
await self._handle_init_packet(packet)
|
|
|
|
|
|
|
|
|
|
print("receive packet", self.expected_payload_length, self.bytes_read)
|
|
|
|
|
printBytes(self.buffer)
|
|
|
|
|
if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read:
|
|
|
|
|
self._finish_message()
|
|
|
|
|
await self._handle_completed_message()
|
|
|
|
@ -103,7 +104,7 @@ class Channel(Context):
|
|
|
|
|
# If the channel does not "own" the buffer lock, decrypt first packet
|
|
|
|
|
# TODO do it only when needed!
|
|
|
|
|
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
|
|
|
|
packet_payload = self._decrypt(packet_payload)
|
|
|
|
|
packet_payload = self._decrypt_single_packet_payload(packet_payload)
|
|
|
|
|
|
|
|
|
|
state = self.get_channel_state()
|
|
|
|
|
|
|
|
|
@ -254,7 +255,7 @@ class Channel(Context):
|
|
|
|
|
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
|
|
|
|
|
|
|
|
|
def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None:
|
|
|
|
|
self._decrypt_buffer()
|
|
|
|
|
self._decrypt_buffer(message_length)
|
|
|
|
|
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
|
|
|
|
|
if session_id == 0:
|
|
|
|
|
self._handle_channel_comms(message_length, message_type)
|
|
|
|
@ -293,18 +294,43 @@ class Channel(Context):
|
|
|
|
|
bufferrone = bytearray(2)
|
|
|
|
|
message_size: int = thp_messages.get_new_session_message(bufferrone)
|
|
|
|
|
print(message_size) # TODO adjust
|
|
|
|
|
loop.schedule(self._write_encrypted_payload_loop(bufferrone))
|
|
|
|
|
loop.schedule(self.write_and_encrypt(bufferrone))
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print("Proč??")
|
|
|
|
|
print(e)
|
|
|
|
|
# TODO not finished
|
|
|
|
|
|
|
|
|
|
def _decrypt(self, payload) -> bytes:
|
|
|
|
|
return payload # TODO add decryption process
|
|
|
|
|
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
|
|
|
|
|
payload_buffer = bytearray(payload)
|
|
|
|
|
crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload))
|
|
|
|
|
return payload_buffer
|
|
|
|
|
|
|
|
|
|
def _decrypt_buffer(self, message_length: int) -> None:
|
|
|
|
|
if not isinstance(self.buffer, bytearray):
|
|
|
|
|
self.buffer = bytearray(self.buffer)
|
|
|
|
|
crypto.decrypt(
|
|
|
|
|
b"\x00",
|
|
|
|
|
b"\x00",
|
|
|
|
|
self.buffer,
|
|
|
|
|
INIT_DATA_OFFSET,
|
|
|
|
|
message_length - INIT_DATA_OFFSET - CHECKSUM_LENGTH,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _decrypt_buffer(self) -> None:
|
|
|
|
|
pass
|
|
|
|
|
# TODO decode buffer in place
|
|
|
|
|
def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None:
|
|
|
|
|
print("\n Encrypting ")
|
|
|
|
|
min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH
|
|
|
|
|
if len(buffer) < min_required_length or not isinstance(buffer, bytearray):
|
|
|
|
|
new_buffer = bytearray(min_required_length)
|
|
|
|
|
utils.memcpy(new_buffer, 0, buffer, 0)
|
|
|
|
|
buffer = new_buffer
|
|
|
|
|
tag = crypto.encrypt(
|
|
|
|
|
b"\x00",
|
|
|
|
|
b"\x00",
|
|
|
|
|
buffer,
|
|
|
|
|
0,
|
|
|
|
|
noise_payload_len,
|
|
|
|
|
)
|
|
|
|
|
buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag
|
|
|
|
|
|
|
|
|
|
async def _buffer_packet_data(
|
|
|
|
|
self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int
|
|
|
|
@ -327,7 +353,7 @@ class Channel(Context):
|
|
|
|
|
log.debug(
|
|
|
|
|
__name__,
|
|
|
|
|
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
|
|
|
|
self.channel_id,
|
|
|
|
|
int.from_bytes(self.channel_id, "big"),
|
|
|
|
|
ack_bit,
|
|
|
|
|
)
|
|
|
|
|
await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH)
|
|
|
|
@ -343,15 +369,18 @@ class Channel(Context):
|
|
|
|
|
|
|
|
|
|
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
|
|
|
|
print("write")
|
|
|
|
|
|
|
|
|
|
noise_payload_len = self._encode_into_buffer(msg, session_id)
|
|
|
|
|
await self.write_and_encrypt(self.buffer[:noise_payload_len])
|
|
|
|
|
|
|
|
|
|
# trezor.crypto.noise.encode(key, payload=self.buffer)
|
|
|
|
|
async def write_and_encrypt(self, payload: bytes) -> None:
|
|
|
|
|
payload_length = len(payload)
|
|
|
|
|
|
|
|
|
|
# TODO payload_len should be output from trezor.crypto.noise.encode, I guess
|
|
|
|
|
payload_len = noise_payload_len # + TAG_LENGTH # TODO
|
|
|
|
|
if not isinstance(self.buffer, bytearray):
|
|
|
|
|
self.buffer = bytearray(self.buffer)
|
|
|
|
|
self._encrypt(self.buffer, payload_length)
|
|
|
|
|
payload_length = payload_length + TAG_LENGTH
|
|
|
|
|
|
|
|
|
|
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len]))
|
|
|
|
|
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_length]))
|
|
|
|
|
|
|
|
|
|
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
|
|
|
|
print("write loop before while")
|
|
|
|
@ -419,10 +448,13 @@ class Channel(Context):
|
|
|
|
|
msg_size = protobuf.encoded_length(msg)
|
|
|
|
|
offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH
|
|
|
|
|
payload_size = offset + msg_size
|
|
|
|
|
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
|
|
|
|
|
|
|
|
|
|
if payload_size > len(self.buffer) or not isinstance(self.buffer, bytearray):
|
|
|
|
|
if required_min_size > len(self.buffer) or not isinstance(
|
|
|
|
|
self.buffer, bytearray
|
|
|
|
|
):
|
|
|
|
|
# message is too big or buffer is not bytearray, we need to allocate a new buffer
|
|
|
|
|
self.buffer = bytearray(payload_size)
|
|
|
|
|
self.buffer = bytearray(required_min_size)
|
|
|
|
|
|
|
|
|
|
buffer = self.buffer
|
|
|
|
|
session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big")
|
|
|
|
@ -445,6 +477,7 @@ class Channel(Context):
|
|
|
|
|
self.sessions[session.session_id] = session
|
|
|
|
|
loop.schedule(session.handle())
|
|
|
|
|
print("new session created. Session id:", session.session_id)
|
|
|
|
|
print(self.sessions)
|
|
|
|
|
|
|
|
|
|
def _todo_clear_buffer(self):
|
|
|
|
|
# TODO Buffer clearing not implemented
|
|
|
|
|