diff --git a/core/src/boot.py b/core/src/boot.py index be33be964..68300fc41 100644 --- a/core/src/boot.py +++ b/core/src/boot.py @@ -46,7 +46,6 @@ async def bootscreen() -> None: lockscreen = Lockscreen(label=storage.device.get_label(), bootscreen=True) while True: try: - if can_lock_device(): enforce_welcome_screen_duration() ui.backlight_fade(ui.style.BACKLIGHT_DIM) diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 0d36faa72..311b6e932 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -160,8 +160,7 @@ def set_int(key: int, value: int) -> None: if key & SESSIONLESS_FLAG: length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG] - - if active_session is None: + elif active_session is None: raise InvalidSessionError else: length = active_session.fields[key] diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 6deaee0ca..075763149 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -206,7 +206,7 @@ async def handle_session( else: ctx_buffer = WIRE_BUFFER - ctx = context.Context(iface, session_id, ctx_buffer) + ctx = context.Context(iface, ctx_buffer, session_id) next_msg: protocol_common.Message | None = None if __debug__ and is_debug_session: diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 35929c2ad..226edfcbb 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -61,10 +61,15 @@ class Context: (i.e., wire, debug, single BT connection, etc.) """ - def __init__(self, iface: WireInterface, buffer: bytearray) -> None: + def __init__( + self, + iface: WireInterface, + buffer: bytearray, + session_id: bytearray | None = None, + ) -> None: self.iface = iface self.buffer = buffer - self.session_id: bytearray | None = None + self.session_id: session_id def read_from_wire(self) -> Awaitable[Message]: """Read a whole message from the wire without parsing it.""" diff --git a/core/src/trezor/wire/protocol.py b/core/src/trezor/wire/protocol.py index e8dd191e2..c90a6eb2c 100644 --- a/core/src/trezor/wire/protocol.py +++ b/core/src/trezor/wire/protocol.py @@ -10,8 +10,8 @@ if TYPE_CHECKING: class WireProtocol: async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message: if utils.USE_THP: - return thp_v1.read_message(iface, buffer) - return codec_v1.read_message(iface, buffer) + return await thp_v1.read_message(iface, buffer) + return await codec_v1.read_message(iface, buffer) async def write_message(iface: WireInterface, message: Message) -> None: if utils.USE_THP: diff --git a/core/src/trezor/wire/thp_session.py b/core/src/trezor/wire/thp_session.py index 203c3bc46..b0a9579c3 100644 --- a/core/src/trezor/wire/thp_session.py +++ b/core/src/trezor/wire/thp_session.py @@ -76,7 +76,7 @@ def get_session_from_id(session_id) -> SessionThpCache | None: return session -def get_state(session: SessionThpCache) -> int: +def get_state(session: SessionThpCache | None) -> int: if session is None: return SessionState.UNALLOCATED return _decode_session_state(session.state) diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 02ab604e7..25efb6b7f 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -66,7 +66,7 @@ class InitHeader: class InterruptingInitPacket: - def __init__(self, report) -> None: + def __init__(self, report: bytes) -> None: self.initReport = report @@ -76,24 +76,23 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag if msg is InterruptingInitPacket: msg = await read_message_or_init_packet(iface, buffer, msg.initReport) else: - raise ThpError("Unexpected output of read_message_or_init_packet") + raise ThpError("Unexpected output of read_message_or_init_packet:") return msg async def read_message_or_init_packet( - iface: WireInterface, buffer: utils.BufferType, firstReport=None + iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None ) -> Message | InterruptingInitPacket: report = firstReport while True: # Wait for an initial report if firstReport is None: report = await _get_loop_wait_read(iface) - # Channel multiplexing ctrl_byte, cid = ustruct.unpack(">BH", report) if cid == BROADCAST_CHANNEL_ID: - _handle_broadcast(iface, ctrl_byte, report) + await _handle_broadcast(iface, ctrl_byte, report) report = None continue @@ -125,7 +124,7 @@ async def read_message_or_init_packet( # Handle message on unallocated channel if session_state == SessionState.UNALLOCATED: - message = _handle_unallocated(iface, cid) + message = await _handle_unallocated(iface, cid) # unallocated should not return regular message, TODO, but it might change if message is not None: return message @@ -145,7 +144,7 @@ async def read_message_or_init_packet( # 2: Handle message with unexpected synchronization bit if sync_bit != THP.sync_get_receive_expected_bit(session): - message = _handle_unexpected_sync_bit(iface, cid, sync_bit) + message = await _handle_unexpected_sync_bit(iface, cid, sync_bit) # unsynchronized messages should not return regular message, TODO, # but it might change with the cancelation message if message is not None: @@ -154,10 +153,10 @@ async def read_message_or_init_packet( continue # 3: Send ACK in response - _sendAck(iface, cid, sync_bit) + _sendAck(iface, cid, sync_bit) # TODO await THP.sync_set_receive_expected_bit(session, 1 - sync_bit) - return _handle_allocated(ctrl_byte, session, payload) + return await _handle_allocated(ctrl_byte, session, payload) def _get_loop_wait_read(iface: WireInterface): @@ -175,7 +174,7 @@ def _get_buffer_for_payload( payload: utils.BufferType = bytearray(payload_length) except MemoryError: payload = bytearray(_REPORT_LENGTH) - raise ("Message too large") + raise ThpError("Message too large") return payload # reuse a part of the supplied buffer @@ -197,7 +196,7 @@ async def _buffer_received_data( # handle broadcast - allows the reading process # to survive interruption by broadcast if cont_cid == BROADCAST_CHANNEL_ID: - _handle_broadcast(iface, cont_ctrl_byte, report) + await _handle_broadcast(iface, cont_ctrl_byte, report) continue # handle unexpected initiation packet @@ -236,7 +235,9 @@ async def write_message( THP.sync_set_send_bit_to_opposite(session) else: # retransmission must have the same sync bit as the previously sent message - ctrl_byte = _add_sync_bit_to_ctrl_byte(ctrl_byte, 1 - THP.sync_get_send_bit()) + ctrl_byte = _add_sync_bit_to_ctrl_byte( + ctrl_byte, 1 - THP.sync_get_send_bit(session) + ) header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH) checksum = _compute_checksum_bytes(header.to_bytes() + payload) @@ -276,7 +277,7 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None: return -def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None: +async def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None: if ctrl_byte != _CHANNEL_ALLOCATION_REQ: raise ThpError("Unexpected ctrl_byte in broadcast channel packet") length, nonce, checksum = ustruct.unpack(">H8s4s", report[3:]) @@ -297,10 +298,10 @@ def _handle_broadcast(iface: WireIntreface, ctrl_byte, report) -> Message | None ) checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data) - write_to_wire(iface, response_header, response_data + checksum) + write_to_wire(iface, response_header, response_data + checksum) # TODO await -def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message: +async def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message: # Parameters session and ctrl_byte will be used to determine if the # communication should be encrypted or not @@ -328,20 +329,20 @@ async def _handle_unallocated(iface, cid) -> Message | None: data = _UNALLOCATED_SESSION_ERROR header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH) checksum = _compute_checksum_bytes(header.to_bytes() + data) - write_to_wire(iface, header, data + checksum) + await write_to_wire(iface, header, data + checksum) async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None: ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit) header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH) checksum = _compute_checksum_bytes(header.to_bytes()) - write_to_wire(iface, header, checksum) + await write_to_wire(iface, header, checksum) -def _handle_unexpected_sync_bit( +async def _handle_unexpected_sync_bit( iface: WireInterface, cid: int, sync_bit: int ) -> Message | None: - _sendAck(iface, cid, sync_bit) + await _sendAck(iface, cid, sync_bit) # TODO handle cancelation messages and messages on allocated channels without synchronization # (some such messages might be handled in the classical "allocated" way, if the sync bit is right) @@ -372,5 +373,5 @@ def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit): raise ThpError("Unexpected synchronization bit") -def _compute_checksum_bytes(data: bytearray): +def _compute_checksum_bytes(data: bytes | utils.BufferType): return crc.crc32(data).to_bytes(4, "big")