1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

fix(core): fix synchronization

This commit is contained in:
M1nd3r 2024-04-02 16:44:12 +02:00
parent e55e9a4e15
commit aa2115542b
2 changed files with 85 additions and 30 deletions

View File

@ -6,7 +6,7 @@ from ubinascii import hexlify
import usb
from storage import cache_thp
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
from trezor import io, loop, protobuf, utils
from trezor import io, log, loop, protobuf, utils
from trezor.messages import ThpCreateNewSession
from trezor.wire import message_handler
from trezor.wire.thp import thp_messages
@ -41,6 +41,8 @@ MESSAGE_TYPE_LENGTH = const(2)
REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000)
_ACK_MESSAGE = 0x20
class Channel(Context):
def __init__(self, channel_cache: ChannelCache) -> None:
@ -139,10 +141,13 @@ class Channel(Context):
async def _handle_completed_message(self):
print("handling completed message")
print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache))
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
msg_len = payload_length + INIT_DATA_OFFSET
print("checksum check")
printBytes(self.buffer)
# printBytes(self.buffer)
if not checksum.is_valid(
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
@ -150,15 +155,41 @@ class Channel(Context):
# checksum is not valid -> ignore message
self._todo_clear_buffer()
return
print("sync bit")
# Synchronization process
sync_bit = (ctrl_byte & 0x10) >> 4
print("sync bit:", sync_bit)
# 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte):
self._handle_received_ACK(sync_bit)
self._todo_clear_buffer()
return
# 2: Handle message with unexpected synchronization bit
if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache):
if __debug__:
log.debug(
__name__, "Received message with an unexpected synchronization bit"
)
await self._sendAck(sync_bit)
raise ThpError("Received message with an unexpected synchronization bit")
# 3: Send ACK in response
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
self.channel_id,
sync_bit,
)
await self._sendAck(sync_bit)
print("___set receive bit to", 1 - sync_bit)
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
state = self.get_channel_state()
_print_state(state)
if __debug__:
log.debug(__name__, _state_to_str(state))
if state is ChannelState.TH1:
if not _is_ctrl_byte_handshake_init:
@ -269,6 +300,21 @@ class Channel(Context):
self.expected_payload_length = 0
self.is_cont_packet_expected = False
async def _sendAck(self, ack_bit: int) -> None:
ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
header = InitHeader(
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
)
chksum = checksum.compute(header.to_bytes())
await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH)
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
if sync_bit == 0:
return ctrl_byte & 0xEF
if sync_bit == 1:
return ctrl_byte | 0x10
raise ThpError("Unexpected synchronization bit")
# CALLED BY WORKFLOW / SESSION CONTEXT
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
@ -286,17 +332,26 @@ class Channel(Context):
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
print("write loop before while")
payload_len = len(payload)
sync_bit = THP.sync_get_send_bit(self.channel_cache)
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
header = InitHeader(
ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
)
# TODO add condition that disallows to write when can_send_message is false
THP.sync_set_can_send_message(self.channel_cache, False)
while True:
print("write encrypted payload loop - start")
print(
"write encrypted payload loop - start, sync_bit:",
header.ctrl_byte & 0x10,
" send_sync_bit:",
THP.sync_get_send_bit(self.channel_cache),
)
await self._write_encrypted_payload(header, payload, payload_len)
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
try:
await self.waiting_for_ack_timeout
except loop.TaskClosed:
THP.sync_set_send_bit_to_opposite(self.channel_cache)
break
async def _write_encrypted_payload(
@ -455,30 +510,29 @@ def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
return ctrl_byte & 0xEF == ACK_MESSAGE
def _print_state(cs: int) -> None:
if cs == ChannelState.ENCRYPTED_TRANSPORT:
print("state: encrypted transport")
elif cs == ChannelState.TH1:
print("state: th1")
elif cs == ChannelState.TH2:
print("state: th2")
elif cs == ChannelState.TP1:
print("state: tp1")
elif cs == ChannelState.TP2:
print("state: tp2")
elif cs == ChannelState.TP3:
print("state: tp3")
elif cs == ChannelState.TP4:
print("state: tp4")
elif cs == ChannelState.TP5:
print("state: tp5")
elif cs == ChannelState.UNALLOCATED:
print("state: unallocated")
elif cs == ChannelState.UNAUTHENTICATED:
print("state: unauthenticated")
def _state_to_str(state: int) -> str:
if state == ChannelState.ENCRYPTED_TRANSPORT:
return "state: encrypted transport"
elif state == ChannelState.TH1:
return "state: th1"
elif state == ChannelState.TH2:
return "state: th2"
elif state == ChannelState.TP1:
return "state: tp1"
elif state == ChannelState.TP2:
return "state: tp2"
elif state == ChannelState.TP3:
return "state: tp3"
elif state == ChannelState.TP4:
return "state: tp4"
elif state == ChannelState.TP5:
return "state: tp5"
elif state == ChannelState.UNALLOCATED:
return "state: unallocated"
elif state == ChannelState.UNAUTHENTICATED:
return "state: unauthenticated"
else:
print(cs)
print("state: <not implemented printout>")
return "state: <not implemented>"
def printBytes(a):

View File

@ -132,7 +132,8 @@ def _sync_set_send_bit(cache: SessionThpCache | ChannelCache, bit: int) -> None:
# set third bit to "bit" value
cache.sync &= 0xDF
cache.sync |= 0x20
if bit:
cache.sync |= 0x20
def _decode_session_state(state: bytearray) -> int: