mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
fix(core): fix synchronization
This commit is contained in:
parent
e55e9a4e15
commit
aa2115542b
@ -6,7 +6,7 @@ from ubinascii import hexlify
|
||||
import usb
|
||||
from storage import cache_thp
|
||||
from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache
|
||||
from trezor import io, loop, protobuf, utils
|
||||
from trezor import io, log, loop, protobuf, utils
|
||||
from trezor.messages import ThpCreateNewSession
|
||||
from trezor.wire import message_handler
|
||||
from trezor.wire.thp import thp_messages
|
||||
@ -41,6 +41,8 @@ MESSAGE_TYPE_LENGTH = const(2)
|
||||
REPORT_LENGTH = const(64)
|
||||
MAX_PAYLOAD_LEN = const(60000)
|
||||
|
||||
_ACK_MESSAGE = 0x20
|
||||
|
||||
|
||||
class Channel(Context):
|
||||
def __init__(self, channel_cache: ChannelCache) -> None:
|
||||
@ -139,10 +141,13 @@ class Channel(Context):
|
||||
|
||||
async def _handle_completed_message(self):
|
||||
print("handling completed message")
|
||||
print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache))
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer)
|
||||
msg_len = payload_length + INIT_DATA_OFFSET
|
||||
|
||||
print("checksum check")
|
||||
printBytes(self.buffer)
|
||||
# printBytes(self.buffer)
|
||||
|
||||
if not checksum.is_valid(
|
||||
checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len],
|
||||
data=self.buffer[: msg_len - CHECKSUM_LENGTH],
|
||||
@ -150,15 +155,41 @@ class Channel(Context):
|
||||
# checksum is not valid -> ignore message
|
||||
self._todo_clear_buffer()
|
||||
return
|
||||
print("sync bit")
|
||||
|
||||
# Synchronization process
|
||||
sync_bit = (ctrl_byte & 0x10) >> 4
|
||||
print("sync bit:", sync_bit)
|
||||
|
||||
# 1: Handle ACKs
|
||||
if _is_ctrl_byte_ack(ctrl_byte):
|
||||
self._handle_received_ACK(sync_bit)
|
||||
self._todo_clear_buffer()
|
||||
return
|
||||
|
||||
# 2: Handle message with unexpected synchronization bit
|
||||
if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache):
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "Received message with an unexpected synchronization bit"
|
||||
)
|
||||
await self._sendAck(sync_bit)
|
||||
raise ThpError("Received message with an unexpected synchronization bit")
|
||||
|
||||
# 3: Send ACK in response
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"Writing ACK message to a channel with id: %d, sync bit: %d",
|
||||
self.channel_id,
|
||||
sync_bit,
|
||||
)
|
||||
await self._sendAck(sync_bit)
|
||||
print("___set receive bit to", 1 - sync_bit)
|
||||
THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit)
|
||||
|
||||
state = self.get_channel_state()
|
||||
_print_state(state)
|
||||
if __debug__:
|
||||
log.debug(__name__, _state_to_str(state))
|
||||
|
||||
if state is ChannelState.TH1:
|
||||
if not _is_ctrl_byte_handshake_init:
|
||||
@ -269,6 +300,21 @@ class Channel(Context):
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
async def _sendAck(self, ack_bit: int) -> None:
|
||||
ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
||||
header = InitHeader(
|
||||
ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH
|
||||
)
|
||||
chksum = checksum.compute(header.to_bytes())
|
||||
await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH)
|
||||
|
||||
def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit):
|
||||
if sync_bit == 0:
|
||||
return ctrl_byte & 0xEF
|
||||
if sync_bit == 1:
|
||||
return ctrl_byte | 0x10
|
||||
raise ThpError("Unexpected synchronization bit")
|
||||
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
@ -286,17 +332,26 @@ class Channel(Context):
|
||||
async def _write_encrypted_payload_loop(self, payload: bytes) -> None:
|
||||
print("write loop before while")
|
||||
payload_len = len(payload)
|
||||
sync_bit = THP.sync_get_send_bit(self.channel_cache)
|
||||
ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit)
|
||||
header = InitHeader(
|
||||
ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len
|
||||
ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len
|
||||
)
|
||||
# TODO add condition that disallows to write when can_send_message is false
|
||||
THP.sync_set_can_send_message(self.channel_cache, False)
|
||||
while True:
|
||||
print("write encrypted payload loop - start")
|
||||
print(
|
||||
"write encrypted payload loop - start, sync_bit:",
|
||||
header.ctrl_byte & 0x10,
|
||||
" send_sync_bit:",
|
||||
THP.sync_get_send_bit(self.channel_cache),
|
||||
)
|
||||
await self._write_encrypted_payload(header, payload, payload_len)
|
||||
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
|
||||
try:
|
||||
await self.waiting_for_ack_timeout
|
||||
except loop.TaskClosed:
|
||||
THP.sync_set_send_bit_to_opposite(self.channel_cache)
|
||||
break
|
||||
|
||||
async def _write_encrypted_payload(
|
||||
@ -455,30 +510,29 @@ def _is_ctrl_byte_ack(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == ACK_MESSAGE
|
||||
|
||||
|
||||
def _print_state(cs: int) -> None:
|
||||
if cs == ChannelState.ENCRYPTED_TRANSPORT:
|
||||
print("state: encrypted transport")
|
||||
elif cs == ChannelState.TH1:
|
||||
print("state: th1")
|
||||
elif cs == ChannelState.TH2:
|
||||
print("state: th2")
|
||||
elif cs == ChannelState.TP1:
|
||||
print("state: tp1")
|
||||
elif cs == ChannelState.TP2:
|
||||
print("state: tp2")
|
||||
elif cs == ChannelState.TP3:
|
||||
print("state: tp3")
|
||||
elif cs == ChannelState.TP4:
|
||||
print("state: tp4")
|
||||
elif cs == ChannelState.TP5:
|
||||
print("state: tp5")
|
||||
elif cs == ChannelState.UNALLOCATED:
|
||||
print("state: unallocated")
|
||||
elif cs == ChannelState.UNAUTHENTICATED:
|
||||
print("state: unauthenticated")
|
||||
def _state_to_str(state: int) -> str:
|
||||
if state == ChannelState.ENCRYPTED_TRANSPORT:
|
||||
return "state: encrypted transport"
|
||||
elif state == ChannelState.TH1:
|
||||
return "state: th1"
|
||||
elif state == ChannelState.TH2:
|
||||
return "state: th2"
|
||||
elif state == ChannelState.TP1:
|
||||
return "state: tp1"
|
||||
elif state == ChannelState.TP2:
|
||||
return "state: tp2"
|
||||
elif state == ChannelState.TP3:
|
||||
return "state: tp3"
|
||||
elif state == ChannelState.TP4:
|
||||
return "state: tp4"
|
||||
elif state == ChannelState.TP5:
|
||||
return "state: tp5"
|
||||
elif state == ChannelState.UNALLOCATED:
|
||||
return "state: unallocated"
|
||||
elif state == ChannelState.UNAUTHENTICATED:
|
||||
return "state: unauthenticated"
|
||||
else:
|
||||
print(cs)
|
||||
print("state: <not implemented printout>")
|
||||
return "state: <not implemented>"
|
||||
|
||||
|
||||
def printBytes(a):
|
||||
|
@ -132,7 +132,8 @@ def _sync_set_send_bit(cache: SessionThpCache | ChannelCache, bit: int) -> None:
|
||||
|
||||
# set third bit to "bit" value
|
||||
cache.sync &= 0xDF
|
||||
cache.sync |= 0x20
|
||||
if bit:
|
||||
cache.sync |= 0x20
|
||||
|
||||
|
||||
def _decode_session_state(state: bytearray) -> int:
|
||||
|
Loading…
Reference in New Issue
Block a user