Fix synchronization

M1nd3r/thp5
M1nd3r 1 month ago
parent ad3b878625
commit 07f6d98a45

@ -6,7 +6,7 @@ from ubinascii import hexlify
import usb
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from trezor import io, loop, protobuf, utils
from trezor import io, log, loop, protobuf, utils
from trezor.messages import ThpCreateNewSession
from trezor.wire import message_handler
from trezor.wire.thp import thp_messages
@ -41,6 +41,8 @@ MESSAGE_TYPE_LENGTH = const(2)
REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000)
_ACK_MESSAGE = 0x20
class Channel(Context):
def __init__(self, channel_cache: ChannelCache) -> None:
@ -139,10 +141,13 @@ class Channel(Context):
async def _handle_completed_message(self):
print("handling completed message")
print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache))
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
msg_len = payload_length + INIT_DATA_OFFSET
print("checksum check")
printBytes(self.buffer)
# printBytes(self.buffer)
if not checksum.is_valid(
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
@ -150,16 +155,40 @@ class Channel(Context):
# checksum is not valid -> ignore message
self._todo_clear_buffer()
return
# Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4
print("sync bit:", sync_bit)
# 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte):
self._handle_received_ACK(sync_bit)
self._todo_clear_buffer()
return
# 2: Handle message with unexpected synchronization bit
if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache):
if __debug__:
log.debug(
__name__, "Received message with an unexpected synchronization bit"
)
raise ThpError("Received message with an unexpected synchronization bit")
# 3: Send ACK in response
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
self.channel_id,
sync_bit,
)
await self._sendAck(sync_bit)
print("___set receive bit to", 1 - sync_bit)
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
state = self.get_channel_state()
_print_state(state)
if __debug__:
log.debug(__name__, _state_to_str(state))
if state is ChannelState.TH1:
if not _is_ctrl_byte_handshake_init:
@ -270,6 +299,21 @@ class Channel(Context):
self.expected_payload_length = 0
self.is_cont_packet_expected = False
async def _sendAck(self, ack_bit: int) -> None:
ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
header = InitHeader(
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
)
chksum = checksum.compute(header.to_bytes())
await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH)
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
if sync_bit == 0:
return ctrl_byte & 0xEF
if sync_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit")
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
@ -287,17 +331,26 @@ class Channel(Context):
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
print("write loop before while")
payload_len = len(payload)
sync_bit = THP.sync_get_send_bit(self.channel_cache)
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
header = InitHeader(
ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
)
# TODO add condition that disallows to write when can_send_message is false
THP.sync_set_can_send_message(self.channel_cache, False)
while True:
print("write encrypted payload loop - start")
print(
"write encrypted payload loop - start, sync_bit:",
header.ctrl_byte & 0x10,
" send_sync_bit:",
THP.sync_get_send_bit(self.channel_cache),
)
await self._write_encrypted_payload(header, payload, payload_len)
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
try:
await self.waiting_for_ack_timeout
except loop.TaskClosed:
THP.sync_set_send_bit_to_opposite(self.channel_cache)
break
async def _write_encrypted_payload(
@ -459,30 +512,29 @@ def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE
def _print_state(cs: int) -> None:
if cs == ChannelState.ENCRYPTED_TRANSPORT:
print("state: encrypted transport")
elif cs == ChannelState.TH1:
print("state: th1")
elif cs == ChannelState.TH2:
print("state: th2")
elif cs == ChannelState.TP1:
print("state: tp1")
elif cs == ChannelState.TP2:
print("state: tp2")
elif cs == ChannelState.TP3:
print("state: tp3")
elif cs == ChannelState.TP4:
print("state: tp4")
elif cs == ChannelState.TP5:
print("state: tp5")
elif cs == ChannelState.UNALLOCATED:
print("state: unallocated")
elif cs == ChannelState.UNAUTHENTICATED:
print("state: unauthenticated")
def _state_to_str(state: int) -> str:
if state == ChannelState.ENCRYPTED_TRANSPORT:
return "state: encrypted transport"
elif state == ChannelState.TH1:
return "state: th1"
elif state == ChannelState.TH2:
return "state: th2"
elif state == ChannelState.TP1:
return "state: tp1"
elif state == ChannelState.TP2:
return "state: tp2"
elif state == ChannelState.TP3:
return "state: tp3"
elif state == ChannelState.TP4:
return "state: tp4"
elif state == ChannelState.TP5:
return "state: tp5"
elif state == ChannelState.UNALLOCATED:
return "state: unallocated"
elif state == ChannelState.UNAUTHENTICATED:
return "state: unauthenticated"
else:
print(cs)
print("state: <not implemented printout>")
return "state: <not implemented>"
def printBytes(a):

Loading…
Cancel
Save