1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

feat(core): create new session response

This commit is contained in:
M1nd3r 2024-04-02 20:40:22 +02:00
parent e7f5f3d3f2
commit 775ad59630
2 changed files with 70 additions and 36 deletions

View File

@ -41,8 +41,6 @@ MESSAGE_TYPE_LENGTH = const(2)
REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000)
_ACK_MESSAGE = 0x20
class Channel(Context):
def __init__(self, channel_cache: ChannelCache) -> None:
@ -163,21 +161,17 @@ class Channel(Context):
log.debug(
__name__, "Received message with an unexpected synchronization bit"
)
await self._sendAck(sync_bit)
await self._send_ack(sync_bit)
raise ThpError("Received message with an unexpected synchronization bit")
# 3: Send ACK in response
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
self.channel_id,
sync_bit,
)
await self._sendAck(sync_bit)
await self._send_ack(sync_bit)
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
self._handle_valid_message(payload_length, message_length, ctrl_byte)
await self._handle_valid_message(
payload_length, message_length, ctrl_byte, sync_bit
)
print("end handle completed message")
def _check_checksum(self, message_length: int):
@ -189,15 +183,16 @@ class Channel(Context):
self._todo_clear_buffer()
raise ThpError("Invalid checksum, ignoring message.")
def _handle_valid_message(
self, payload_length: int, message_length: int, ctrl_byte: int
async def _handle_valid_message(
self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
) -> None:
state = self.get_channel_state()
if __debug__:
log.debug(__name__, _state_to_str(state))
if state is ChannelState.TH1:
self._handle_state_TH1(payload_length, message_length)
await self._handle_state_TH1(payload_length, message_length, sync_bit)
return
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
self._todo_clear_buffer()
@ -205,11 +200,15 @@ class Channel(Context):
if state is ChannelState.ENCRYPTED_TRANSPORT:
self._handle_state_ENCRYPTED_TRANSPORT(message_length)
return
if state is ChannelState.TH2:
self._handle_state_TH2(message_length)
await self._handle_state_TH2(message_length, sync_bit)
return
def _handle_state_TH1(self, payload_length: int, message_length: int) -> None:
async def _handle_state_TH1(
self, payload_length: int, message_length: int, sync_bit: int
) -> None:
if not _is_ctrl_byte_handshake_init:
raise ThpError("Message received is not a handshake init request!")
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
@ -218,8 +217,10 @@ class Channel(Context):
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
)
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
# TODO send ack in response
# TODO send handshake init response message
await self._send_ack(sync_bit)
# send handshake init response message
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
@ -228,7 +229,7 @@ class Channel(Context):
self.set_channel_state(ChannelState.TH2)
return
def _handle_state_TH2(self, message_length: int) -> None:
async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None:
print("th2 branche")
host_encrypted_static_pubkey = self.buffer[
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
@ -243,11 +244,11 @@ class Channel(Context):
host_encrypted_static_pubkey,
handshake_completion_request_noise_payload,
) # TODO remove
# TODO send ack in response
# TODO send hanshake completion response
# send hanshake completion response
loop.schedule(
self._write_encrypted_payload_loop(
thp_messages.get_handshake_init_response()
thp_messages.get_handshake_completion_response()
)
)
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
@ -288,6 +289,11 @@ class Channel(Context):
self.create_new_session(message.passphrase)
else:
self.create_new_session()
# TODO reuse existing buffer and compute size dynamically
bufferrone = bytearray(2)
message_size: int = thp_messages.get_new_session_message(bufferrone)
print(message_size) # TODO adjust
loop.schedule(self._write_encrypted_payload_loop(bufferrone))
except Exception as e:
print("Proč??")
print(e)
@ -311,13 +317,20 @@ class Channel(Context):
self.expected_payload_length = 0
self.is_cont_packet_expected = False
async def _sendAck(self, ack_bit: int) -> None:
ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
async def _send_ack(self, ack_bit: int) -> None:
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
header = InitHeader(
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
)
chksum = checksum.compute(header.to_bytes())
await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH)
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
self.channel_id,
ack_bit,
)
await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH)
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
if sync_bit == 0:
@ -335,19 +348,22 @@ class Channel(Context):
# trezor.crypto.noise.encode(key, payload=self.buffer)
# TODO payload_len should be output from trezor.crypto.noise.encode
# TODO payload_len should be output from trezor.crypto.noise.encode, I guess
payload_len = noise_payload_len # + TAG_LENGTH # TODO
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len]))
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
print("write loop before while")
payload_len = len(payload)
payload_len = len(payload) + CHECKSUM_LENGTH
sync_bit = THP.sync_get_send_bit(self.channel_cache)
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
header = InitHeader(
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
)
chksum = checksum.compute(header.to_bytes() + payload)
payload = payload + chksum
# TODO add condition that disallows to write when can_send_message is false
THP.sync_set_can_send_message(self.channel_cache, False)
while True:
@ -357,7 +373,7 @@ class Channel(Context):
" send_sync_bit:",
THP.sync_get_send_bit(self.channel_cache),
)
await self._write_encrypted_payload(header, payload, payload_len)
await self._write_payload_to_wire(header, payload, payload_len)
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
try:
await self.waiting_for_ack_timeout
@ -365,28 +381,29 @@ class Channel(Context):
THP.sync_set_send_bit_to_opposite(self.channel_cache)
break
async def _write_encrypted_payload(
async def _write_payload_to_wire(
self, header: InitHeader, payload: bytes, payload_len: int
):
print("write payload to wire:")
# prepare the report buffer with header data
report = bytearray(REPORT_LENGTH)
header.pack_to_buffer(report)
# write initial report
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
await self._write_report(report)
await self._write_report_to_wire(report)
# if we have more data to write, use continuation reports for it
if nwritten < payload_len:
header.pack_to_cont_buffer(report)
while nwritten < payload_len:
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
await self._write_report(report)
await self._write_report_to_wire(report)
async def _write_report(self, report: utils.BufferType) -> None:
async def _write_report_to_wire(self, report: utils.BufferType) -> None:
while True:
await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
printBytes(report) # TODO remove
n = self.iface.write(report)
if n == len(report):
return

View File

@ -2,7 +2,7 @@ import ustruct # pyright:ignore[reportMissingModuleSource]
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import protobuf
from trezor.messages import ThpCreateNewSession
from trezor.messages import ThpCreateNewSession, ThpNewSession
from .. import message_handler
from ..protocol_common import Message
@ -15,6 +15,9 @@ ACK_MESSAGE = 0x20
_ERROR = 0x41
_CHANNEL_ALLOCATION_RES = 0x40
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
class InitHeader:
format_str = ">BHH"
@ -79,7 +82,21 @@ def get_error_unallocated_channel() -> bytes:
def get_handshake_init_response() -> bytes:
return b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" # TODO implement
# TODO implement - 32 bytes ephemeral key, 48 bytes encrypted and masked public key, 16 bytes ciphertext of empty string (i.e. noise tag)
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x40\x41\x42\x43\x44\x45\x46\x47\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
def get_handshake_completion_response() -> bytes:
return (
TREZOR_STATE_PAIRED
+ b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
)
def get_new_session_message(buffer: bytearray) -> int:
msg = ThpNewSession(new_session_id=1)
encoded_msg = protobuf.encode(buffer, msg)
return encoded_msg
def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType: