Fix bugs in synchronization and finding handlers

M1nd3r/thp5
M1nd3r 2 months ago committed by M1nd3r
parent 8c5faff65d
commit d75d875d9f

@ -367,7 +367,7 @@ def set_homescreen() -> None:
def lock_device(interrupt_workflow: bool = True) -> None:
if config.has_pin():
config.lock()
wire.find_handler = get_pinlocked_handler
wire.message_handler.find_handler = get_pinlocked_handler
set_homescreen()
if interrupt_workflow:
workflow.close_others()
@ -403,7 +403,7 @@ async def unlock_device() -> None:
_SCREENSAVER_IS_ON = False
set_homescreen()
wire.find_handler = workflow_handlers.find_registered_handler
wire.message_handler.find_handler = workflow_handlers.find_registered_handler
def get_pinlocked_handler(
@ -462,6 +462,6 @@ def boot() -> None:
reload_settings_from_storage()
if config.is_unlocked():
wire.find_handler = workflow_handlers.find_registered_handler
wire.message_handler.find_handler = workflow_handlers.find_registered_handler
else:
wire.find_handler = get_pinlocked_handler
wire.message_handler.find_handler = get_pinlocked_handler

@ -1,4 +1,5 @@
from storage.cache_thp import SessionThpCache
from trezor import log
from . import thp_session as THP
@ -6,11 +7,17 @@ from . import thp_session as THP
def handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
if _ack_is_not_expected(session):
if __debug__:
log.debug(__name__, "Received unexpected ACK message")
return
if _ack_has_incorrect_sync_bit(session, sync_bit):
if __debug__:
log.debug(__name__, "Received ACK message with wrong sync bit")
return
# ACK is expected and it has correct sync bit
if __debug__:
log.debug(__name__, "Received ACK message with correct sync bit")
THP.sync_set_can_send_message(session, True)

@ -90,7 +90,8 @@ def sync_set_receive_expected_bit(session: SessionThpCache, bit: int) -> None:
# set second bit to "bit" value
session.sync &= 0xBF
session.sync |= 0x40
if bit:
session.sync |= 0x40
def sync_set_send_bit_to_opposite(session: SessionThpCache) -> None:

@ -3,7 +3,7 @@ from micropython import const
from typing import TYPE_CHECKING
from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache
from trezor import io, loop, utils
from trezor import io, log, loop, utils
from .protocol_common import MessageWithId
from .thp import ack_handler, checksum, thp_messages
@ -101,6 +101,8 @@ async def read_message_or_init_packet(
# the sole exception of cid_request which can be handled independently.
if _is_ctrl_byte_continuation(ctrl_byte):
# continuation packet is not expected - ignore
if __debug__:
log.debug(__name__, "Received unexpected continuation packet")
report = None
continue
payload_length = ustruct.unpack(">H", report[3:])[0]
@ -125,6 +127,8 @@ async def read_message_or_init_packet(
if session_state == SessionState.UNALLOCATED:
message = await _handle_unallocated(iface, cid)
# unallocated should not return regular message, TODO, but it might change
if __debug__:
log.debug(__name__, "Channel with id: %d in UNALLOCATED", cid)
if message is not None:
return message
report = None
@ -155,6 +159,13 @@ async def read_message_or_init_packet(
continue
# 3: Send ACK in response
if __debug__:
log.debug(
__name__,
"Writing ACK message to a channel with id: %d, sync bit: %d",
cid,
sync_bit,
)
await _sendAck(iface, cid, sync_bit)
THP.sync_set_receive_expected_bit(session, 1 - sync_bit)
@ -257,6 +268,13 @@ async def write_message(
header = InitHeader(ctrl_byte, cid, payload_length + CHECKSUM_LENGTH)
chksum = checksum.compute(header.to_bytes() + payload)
if __debug__ and message.session_id is not None:
log.debug(
__name__,
"Writing message to a session with id: %d, message type: %d, ",
int.from_bytes(message.session_id, "big"),
message.type,
)
await write_to_wire(iface, header, payload + chksum)
# TODO set timeout for retransmission
@ -299,6 +317,8 @@ async def _handle_broadcast(
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
if __debug__:
log.debug(__name__, "Received valid message on broadcast channel ")
length, nonce = ustruct.unpack(">H8s", packet[3:])
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
@ -314,6 +334,8 @@ async def _handle_broadcast(
len(response_data) + CHECKSUM_LENGTH,
)
chksum = checksum.compute(response_header.to_bytes() + response_data)
if __debug__:
log.debug(__name__, "New channel allocated with id %d", channel_id)
await write_to_wire(iface, response_header, response_data + chksum)
@ -328,6 +350,8 @@ async def _handle_allocated(
# trim message type and checksum from payload
message_data = payload[2:-CHECKSUM_LENGTH]
if __debug__:
log.debug(__name__, "Received valid message with type %d", message_type)
return MessageWithId(message_type, message_data, session.session_id)
@ -348,6 +372,8 @@ async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
async def _handle_unexpected_sync_bit(
iface: WireInterface, cid: int, sync_bit: int
) -> MessageWithId | None:
if __debug__:
log.debug(__name__, "Received message has unexpected synchronization bit")
await _sendAck(iface, cid, sync_bit)
# TODO handle cancelation messages and messages on allocated channels without synchronization

Loading…
Cancel
Save