|
|
|
@ -41,8 +41,6 @@ MESSAGE_TYPE_LENGTH = const(2)
|
|
|
|
|
REPORT_LENGTH = const(64)
|
|
|
|
|
MAX_PAYLOAD_LEN = const(60000)
|
|
|
|
|
|
|
|
|
|
_ACK_MESSAGE = 0x20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Channel(Context):
|
|
|
|
|
def __init__(self, channel_cache: ChannelCache) -> None:
|
|
|
|
@ -163,21 +161,17 @@ class Channel(Context):
|
|
|
|
|
log.debug(
|
|
|
|
|
__name__, "Received message with an unexpected synchronization bit"
|
|
|
|
|
)
|
|
|
|
|
await self._sendAck(sync_bit)
|
|
|
|
|
await self._send_ack(sync_bit)
|
|
|
|
|
raise ThpError("Received message with an unexpected synchronization bit")
|
|
|
|
|
|
|
|
|
|
# 3: Send ACK in response
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(
|
|
|
|
|
__name__,
|
|
|
|
|
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
|
|
|
|
self.channel_id,
|
|
|
|
|
sync_bit,
|
|
|
|
|
)
|
|
|
|
|
await self._sendAck(sync_bit)
|
|
|
|
|
await self._send_ack(sync_bit)
|
|
|
|
|
|
|
|
|
|
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
|
|
|
|
|
|
|
|
|
|
self._handle_valid_message(payload_length, message_length, ctrl_byte)
|
|
|
|
|
await self._handle_valid_message(
|
|
|
|
|
payload_length, message_length, ctrl_byte, sync_bit
|
|
|
|
|
)
|
|
|
|
|
print("end handle completed message")
|
|
|
|
|
|
|
|
|
|
def _check_checksum(self, message_length: int):
|
|
|
|
@ -189,15 +183,16 @@ class Channel(Context):
|
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
|
raise ThpError("Invalid checksum, ignoring message.")
|
|
|
|
|
|
|
|
|
|
def _handle_valid_message(
|
|
|
|
|
self, payload_length: int, message_length: int, ctrl_byte: int
|
|
|
|
|
async def _handle_valid_message(
|
|
|
|
|
self, payload_length: int, message_length: int, ctrl_byte: int, sync_bit: int
|
|
|
|
|
) -> None:
|
|
|
|
|
state = self.get_channel_state()
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(__name__, _state_to_str(state))
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH1:
|
|
|
|
|
self._handle_state_TH1(payload_length, message_length)
|
|
|
|
|
await self._handle_state_TH1(payload_length, message_length, sync_bit)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
|
|
|
|
self._todo_clear_buffer()
|
|
|
|
@ -205,11 +200,15 @@ class Channel(Context):
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.ENCRYPTED_TRANSPORT:
|
|
|
|
|
self._handle_state_ENCRYPTED_TRANSPORT(message_length)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if state is ChannelState.TH2:
|
|
|
|
|
self._handle_state_TH2(message_length)
|
|
|
|
|
await self._handle_state_TH2(message_length, sync_bit)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
def _handle_state_TH1(self, payload_length: int, message_length: int) -> None:
|
|
|
|
|
async def _handle_state_TH1(
|
|
|
|
|
self, payload_length: int, message_length: int, sync_bit: int
|
|
|
|
|
) -> None:
|
|
|
|
|
if not _is_ctrl_byte_handshake_init:
|
|
|
|
|
raise ThpError("Message received is not a handshake init request!")
|
|
|
|
|
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
|
|
|
@ -218,8 +217,10 @@ class Channel(Context):
|
|
|
|
|
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
|
|
|
|
|
)
|
|
|
|
|
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
|
|
|
|
|
# TODO send ack in response
|
|
|
|
|
# TODO send handshake init response message
|
|
|
|
|
|
|
|
|
|
await self._send_ack(sync_bit)
|
|
|
|
|
|
|
|
|
|
# send handshake init response message
|
|
|
|
|
loop.schedule(
|
|
|
|
|
self._write_encrypted_payload_loop(
|
|
|
|
|
thp_messages.get_handshake_init_response()
|
|
|
|
@ -228,7 +229,7 @@ class Channel(Context):
|
|
|
|
|
self.set_channel_state(ChannelState.TH2)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
def _handle_state_TH2(self, message_length: int) -> None:
|
|
|
|
|
async def _handle_state_TH2(self, message_length: int, sync_bit: int) -> None:
|
|
|
|
|
print("th2 branche")
|
|
|
|
|
host_encrypted_static_pubkey = self.buffer[
|
|
|
|
|
INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH
|
|
|
|
@ -243,11 +244,11 @@ class Channel(Context):
|
|
|
|
|
host_encrypted_static_pubkey,
|
|
|
|
|
handshake_completion_request_noise_payload,
|
|
|
|
|
) # TODO remove
|
|
|
|
|
# TODO send ack in response
|
|
|
|
|
# TODO send hanshake completion response
|
|
|
|
|
|
|
|
|
|
# send hanshake completion response
|
|
|
|
|
loop.schedule(
|
|
|
|
|
self._write_encrypted_payload_loop(
|
|
|
|
|
thp_messages.get_handshake_init_response()
|
|
|
|
|
thp_messages.get_handshake_completion_response()
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
|
|
|
|
@ -288,6 +289,11 @@ class Channel(Context):
|
|
|
|
|
self.create_new_session(message.passphrase)
|
|
|
|
|
else:
|
|
|
|
|
self.create_new_session()
|
|
|
|
|
# TODO reuse existing buffer and compute size dynamically
|
|
|
|
|
bufferrone = bytearray(2)
|
|
|
|
|
message_type: int = thp_messages.get_new_session_message(bufferrone)
|
|
|
|
|
|
|
|
|
|
loop.schedule(self._write_encrypted_payload_loop(bufferrone))
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print("Proč??")
|
|
|
|
|
print(e)
|
|
|
|
@ -311,13 +317,20 @@ class Channel(Context):
|
|
|
|
|
self.expected_payload_length = 0
|
|
|
|
|
self.is_cont_packet_expected = False
|
|
|
|
|
|
|
|
|
|
async def _sendAck(self, ack_bit: int) -> None:
|
|
|
|
|
ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
|
|
|
|
async def _send_ack(self, ack_bit: int) -> None:
|
|
|
|
|
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ACK_MESSAGE, ack_bit)
|
|
|
|
|
header = InitHeader(
|
|
|
|
|
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
|
|
|
|
|
)
|
|
|
|
|
chksum = checksum.compute(header.to_bytes())
|
|
|
|
|
await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH)
|
|
|
|
|
if __debug__:
|
|
|
|
|
log.debug(
|
|
|
|
|
__name__,
|
|
|
|
|
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
|
|
|
|
self.channel_id,
|
|
|
|
|
ack_bit,
|
|
|
|
|
)
|
|
|
|
|
await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH)
|
|
|
|
|
|
|
|
|
|
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
|
|
|
|
|
if sync_bit == 0:
|
|
|
|
@ -335,19 +348,22 @@ class Channel(Context):
|
|
|
|
|
|
|
|
|
|
# trezor.crypto.noise.encode(key, payload=self.buffer)
|
|
|
|
|
|
|
|
|
|
# TODO payload_len should be output from trezor.crypto.noise.encode
|
|
|
|
|
# TODO payload_len should be output from trezor.crypto.noise.encode, I guess
|
|
|
|
|
payload_len = noise_payload_len # + TAG_LENGTH # TODO
|
|
|
|
|
|
|
|
|
|
loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len]))
|
|
|
|
|
|
|
|
|
|
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
|
|
|
|
print("write loop before while")
|
|
|
|
|
payload_len = len(payload)
|
|
|
|
|
payload_len = len(payload) + CHECKSUM_LENGTH
|
|
|
|
|
sync_bit = THP.sync_get_send_bit(self.channel_cache)
|
|
|
|
|
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
|
|
|
|
|
header = InitHeader(
|
|
|
|
|
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
|
|
|
|
|
)
|
|
|
|
|
chksum = checksum.compute(header.to_bytes() + payload)
|
|
|
|
|
payload = payload + chksum
|
|
|
|
|
|
|
|
|
|
# TODO add condition that disallows to write when can_send_message is false
|
|
|
|
|
THP.sync_set_can_send_message(self.channel_cache, False)
|
|
|
|
|
while True:
|
|
|
|
@ -357,7 +373,7 @@ class Channel(Context):
|
|
|
|
|
" send_sync_bit:",
|
|
|
|
|
THP.sync_get_send_bit(self.channel_cache),
|
|
|
|
|
)
|
|
|
|
|
await self._write_encrypted_payload(header, payload, payload_len)
|
|
|
|
|
await self._write_payload_to_wire(header, payload, payload_len)
|
|
|
|
|
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
|
|
|
|
|
try:
|
|
|
|
|
await self.waiting_for_ack_timeout
|
|
|
|
@ -365,28 +381,29 @@ class Channel(Context):
|
|
|
|
|
THP.sync_set_send_bit_to_opposite(self.channel_cache)
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
async def _write_encrypted_payload(
|
|
|
|
|
async def _write_payload_to_wire(
|
|
|
|
|
self, header: InitHeader, payload: bytes, payload_len: int
|
|
|
|
|
):
|
|
|
|
|
|
|
|
|
|
print("write payload to wire:")
|
|
|
|
|
# prepare the report buffer with header data
|
|
|
|
|
report = bytearray(REPORT_LENGTH)
|
|
|
|
|
header.pack_to_buffer(report)
|
|
|
|
|
|
|
|
|
|
# write initial report
|
|
|
|
|
nwritten = utils.memcpy(report, INIT_DATA_OFFSET, payload, 0)
|
|
|
|
|
await self._write_report(report)
|
|
|
|
|
await self._write_report_to_wire(report)
|
|
|
|
|
|
|
|
|
|
# if we have more data to write, use continuation reports for it
|
|
|
|
|
if nwritten < payload_len:
|
|
|
|
|
header.pack_to_cont_buffer(report)
|
|
|
|
|
while nwritten < payload_len:
|
|
|
|
|
nwritten += utils.memcpy(report, CONT_DATA_OFFSET, payload, nwritten)
|
|
|
|
|
await self._write_report(report)
|
|
|
|
|
await self._write_report_to_wire(report)
|
|
|
|
|
|
|
|
|
|
async def _write_report(self, report: utils.BufferType) -> None:
|
|
|
|
|
async def _write_report_to_wire(self, report: utils.BufferType) -> None:
|
|
|
|
|
while True:
|
|
|
|
|
await loop.wait(self.iface.iface_num() | io.POLL_WRITE)
|
|
|
|
|
printBytes(report) # TODO remove
|
|
|
|
|
n = self.iface.write(report)
|
|
|
|
|
if n == len(report):
|
|
|
|
|
return
|
|
|
|
|