mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
feat(core): create new session response
This commit is contained in:
parent
e7f5f3d3f2
commit
775ad59630
@ -41,8 +41,6 @@ MESSAGE_TYPE_LENGTH = const(2)
|
||||
REPORT_LENGTH = const(64)
|
||||
MAX_PAYLOAD_LEN = const(60000)
|
||||
|
||||
_ACK_MESSAGE = 0x20
|
||||
|
||||
|
||||
class Channel(Context):
|
||||
def __init__(self, channel_cache: ChannelCache) -> None:
|
||||
@ -163,21 +161,17 @@ class Channel(Context):
|
||||
log.debug(
|
||||
__name__, "Received message with an unexpected synchronization bit"
|
||||
)
|
||||
await self._sendAck(sync_bit)
|
||||
await self._send_ack(sync_bit)
|
||||
raise ThpError("Received message with an unexpected synchronization bit")
|
||||
|
||||
# 3: Send ACK in response
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
||||
self.channel_id,
|
||||
sync_bit,
|
||||
)
|
||||
await self._sendAck(sync_bit)
|
||||
await self._send_ack(sync_bit)
|
||||
|
||||
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
|
||||
|
||||
self._handle_valid_message(payload_length, message_length, ctrl_byte)
|
||||
await self._handle_valid_message(
|
||||
payload_length, message_length, ctrl_byte, sync_bit
|
||||
)
|
||||
print("end handle completed message")
|
||||
|
||||
def _check_checksum(self, message_length: int):
|
||||
@ -189,15 +183,16 @@ class Channel(Context):
|
||||
self._todo_clear_buffer()
|
||||
raise ThpError("Invalid checksum, ignoring message.")
|
||||
|
||||
def _handle_valid_message(
|
||||
self, payload_length: int, message_length: int, ctrl_byte: int
|
||||
async def _handle_valid_message(
|
||||
self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
|
||||
) -> None:
|
||||
state = self.get_channel_state()
|
||||
if __debug__:
|
||||
log.debug(__name__, _state_to_str(state))
|
||||
|
||||
if state is ChannelState.TH1:
|
||||
self._handle_state_TH1(payload_length, message_length)
|
||||
await self._handle_state_TH1(payload_length, message_length, sync_bit)
|
||||
return
|
||||
|
||||
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
||||
self._todo_clear_buffer()
|
||||
@ -205,11 +200,15 @@ class Channel(Context):
|
||||
|
||||
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
||||
self._handle_state_ENCRYPTED_TRANSPORT(message_length)
|
||||
return
|
||||
|
||||
if state is ChannelState.TH2:
|
||||
self._handle_state_TH2(message_length)
|
||||
await self._handle_state_TH2(message_length, sync_bit)
|
||||
return
|
||||
|
||||
def _handle_state_TH1(self, payload_length: int, message_length: int) -> None:
|
||||
async def _handle_state_TH1(
|
||||
self, payload_length: int, message_length: int, sync_bit: int
|
||||
) -> None:
|
||||
if not _is_ctrl_byte_handshake_init:
|
||||
raise ThpError("Message received is not a handshake init request!")
|
||||
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
||||
@ -218,8 +217,10 @@ class Channel(Context):
|
||||
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
|
||||
)
|
||||
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
|
||||
# TODO send ack in response
|
||||
# TODO send handshake init response message
|
||||
|
||||
await self._send_ack(sync_bit)
|
||||
|
||||
# send handshake init response message
|
||||
loop.schedule(
|
||||
self._write_encrypted_payload_loop(
|
||||
thp_messages.get_handshake_init_response()
|
||||
@ -228,7 +229,7 @@ class Channel(Context):
|
||||
self.set_channel_state(ChannelState.TH2)
|
||||
return
|
||||
|
||||
def _handle_state_TH2(self, message_length: int) -> None:
|
||||
async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None:
|
||||
print("th2 branche")
|
||||
host_encrypted_static_pubkey = self.buffer[
|
||||
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
||||
@ -243,11 +244,11 @@ class Channel(Context):
|
||||
host_encrypted_static_pubkey,
|
||||
handshake_completion_request_noise_payload,
|
||||
) # TODO remove
|
||||
# TODO send ack in response
|
||||
# TODO send hanshake completion response
|
||||
|
||||
# send hanshake completion response
|
||||
loop.schedule(
|
||||
self._write_encrypted_payload_loop(
|
||||
thp_messages.get_handshake_init_response()
|
||||
thp_messages.get_handshake_completion_response()
|
||||
)
|
||||
)
|
||||
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
||||
@ -288,6 +289,11 @@ class Channel(Context):
|
||||
self.create_new_session(message.passphrase)
|
||||
else:
|
||||
self.create_new_session()
|
||||
# TODO reuse existing buffer and compute size dynamically
|
||||
bufferrone = bytearray(2)
|
||||
message_size: int = thp_messages.get_new_session_message(bufferrone)
|
||||
print(message_size) # TODO adjust
|
||||
loop.schedule(self._write_encrypted_payload_loop(bufferrone))
|
||||
except Exception as e:
|
||||
print("Proč??")
|
||||
print(e)
|
||||
@ -311,13 +317,20 @@ class Channel(Context):
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
async def _sendAck(self, ack_bit: int) -> None:
|
||||
ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
||||
async def _send_ack(self, ack_bit: int) -> None:
|
||||
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
||||
header = InitHeader(
|
||||
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
|
||||
)
|
||||
chksum = checksum.compute(header.to_bytes())
|
||||
await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
||||
self.channel_id,
|
||||
ack_bit,
|
||||
)
|
||||
await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH)
|
||||
|
||||
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
|
||||
if sync_bit == 0:
|
||||
@ -335,19 +348,22 @@ class Channel(Context):
|
||||
|
||||
# trezor.crypto.noise.encode(key, payload=self.buffer)
|
||||
|
||||
# TODO payload_len should be output from trezor.crypto.noise.encode
|
||||
# TODO payload_len should be output from trezor.crypto.noise.encode, I guess
|
||||
payload_len = noise_payload_len # + TAG_LENGTH # TODO
|
||||
|
||||
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len]))
|
||||
|
||||
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
||||
print("write loop before while")
|
||||
payload_len = len(payload)
|
||||
payload_len = len(payload) + CHECKSUM_LENGTH
|
||||
sync_bit = THP.sync_get_send_bit(self.channel_cache)
|
||||
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
|
||||
header = InitHeader(
|
||||
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
|
||||
)
|
||||
chksum = checksum.compute(header.to_bytes() + payload)
|
||||
payload = payload + chksum
|
||||
|
||||
# TODO add condition that disallows to write when can_send_message is false
|
||||
THP.sync_set_can_send_message(self.channel_cache, False)
|
||||
while True:
|
||||
@ -357,7 +373,7 @@ class Channel(Context):
|
||||
" send_sync_bit:",
|
||||
THP.sync_get_send_bit(self.channel_cache),
|
||||
)
|
||||
await self._write_encrypted_payload(header, payload, payload_len)
|
||||
await self._write_payload_to_wire(header, payload, payload_len)
|
||||
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
|
||||
try:
|
||||
await self.waiting_for_ack_timeout
|
||||
@ -365,28 +381,29 @@ class Channel(Context):
|
||||
THP.sync_set_send_bit_to_opposite(self.channel_cache)
|
||||
break
|
||||
|
||||
async def _write_encrypted_payload(
|
||||
async def _write_payload_to_wire(
|
||||
self, header: InitHeader, payload: bytes, payload_len: int
|
||||
):
|
||||
|
||||
print("write payload to wire:")
|
||||
# prepare the report buffer with header data
|
||||
report = bytearray(REPORT_LENGTH)
|
||||
header.pack_to_buffer(report)
|
||||
|
||||
# write initial report
|
||||
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
|
||||
await self._write_report(report)
|
||||
await self._write_report_to_wire(report)
|
||||
|
||||
# if we have more data to write, use continuation reports for it
|
||||
if nwritten < payload_len:
|
||||
header.pack_to_cont_buffer(report)
|
||||
while nwritten < payload_len:
|
||||
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
|
||||
await self._write_report(report)
|
||||
await self._write_report_to_wire(report)
|
||||
|
||||
async def _write_report(self, report: utils.BufferType) -> None:
|
||||
async def _write_report_to_wire(self, report: utils.BufferType) -> None:
|
||||
while True:
|
||||
await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
|
||||
printBytes(report) # TODO remove
|
||||
n = self.iface.write(report)
|
||||
if n == len(report):
|
||||
return
|
||||
|
@ -2,7 +2,7 @@ import ustruct # pyright:ignore[reportMissingModuleSource]
|
||||
|
||||
from storage.cache_thp import BROADCAST_CHANNEL_ID
|
||||
from trezor import protobuf
|
||||
from trezor.messages import ThpCreateNewSession
|
||||
from trezor.messages import ThpCreateNewSession, ThpNewSession
|
||||
|
||||
from .. import message_handler
|
||||
from ..protocol_common import Message
|
||||
@ -15,6 +15,9 @@ ACK_MESSAGE = 0x20
|
||||
_ERROR = 0x41
|
||||
_CHANNEL_ALLOCATION_RES = 0x40
|
||||
|
||||
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||
TREZOR_STATE_PAIRED = b"\x01"
|
||||
|
||||
|
||||
class InitHeader:
|
||||
format_str = ">BHH"
|
||||
@ -79,7 +82,21 @@ def get_error_unallocated_channel() -> bytes:
|
||||
|
||||
|
||||
def get_handshake_init_response() -> bytes:
|
||||
return b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03" # TODO implement
|
||||
# TODO implement - 32 bytes ephemeral key, 48 bytes encrypted and masked public key, 16 bytes ciphertext of empty string (i.e. noise tag)
|
||||
return b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x20\x21\x22\x23\x24\x25\x26\x27\x28\x29\x30\x31\x32\x33\x34\x35\x36\x37\x38\x39\x40\x41\x42\x43\x44\x45\x46\x47\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
|
||||
|
||||
|
||||
def get_handshake_completion_response() -> bytes:
|
||||
return (
|
||||
TREZOR_STATE_PAIRED
|
||||
+ b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
|
||||
)
|
||||
|
||||
|
||||
def get_new_session_message(buffer: bytearray) -> int:
|
||||
msg = ThpNewSession(new_session_id=1)
|
||||
encoded_msg = protobuf.encode(buffer, msg)
|
||||
return encoded_msg
|
||||
|
||||
|
||||
def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType:
|
||||
|
Loading…
Reference in New Issue
Block a user