1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-05 13:01:12 +00:00

feat(core): introduce timeout to loop.wait()

This commit is contained in:
matejcik 2023-06-29 17:08:13 +02:00 committed by matejcik
parent 1efb684544
commit b7b09b8836
2 changed files with 51 additions and 11 deletions

View File

@ -374,6 +374,7 @@ async def _read_cmd(iface: HID) -> Cmd | None:
desc_cont = frame_cont() desc_cont = frame_cont()
read = loop.wait(iface.iface_num() | io.POLL_READ) read = loop.wait(iface.iface_num() | io.POLL_READ)
# wait for incoming comand indefinitely
buf = await read buf = await read
while True: while True:
ifrm = overlay_struct(bytearray(buf), desc_init) ifrm = overlay_struct(bytearray(buf), desc_init)
@ -409,9 +410,12 @@ async def _read_cmd(iface: HID) -> Cmd | None:
else: else:
data = data[:bcnt] data = data[:bcnt]
# set a timeout for subsequent reads
read.timeout_ms = _CTAP_HID_TIMEOUT_MS
while datalen < bcnt: while datalen < bcnt:
buf = await loop.race(read, loop.sleep(_CTAP_HID_TIMEOUT_MS)) try:
if not isinstance(buf, bytes): buf = await read
except loop.Timeout:
if __debug__: if __debug__:
warning(__name__, "_ERR_MSG_TIMEOUT") warning(__name__, "_ERR_MSG_TIMEOUT")
await send_cmd(cmd_error(ifrm_cid, _ERR_MSG_TIMEOUT), iface) await send_cmd(cmd_error(ifrm_cid, _ERR_MSG_TIMEOUT), iface)
@ -494,7 +498,9 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None:
if offset < datalen: if offset < datalen:
frm = overlay_struct(buf, cont_desc) frm = overlay_struct(buf, cont_desc)
write = loop.wait(iface.iface_num() | io.POLL_WRITE) write = loop.wait(
iface.iface_num() | io.POLL_WRITE, timeout_ms=_CTAP_HID_TIMEOUT_MS
)
while offset < datalen: while offset < datalen:
frm.seq = seq frm.seq = seq
copied = utils.memcpy(frm.data, 0, cmd.data, offset, datalen) copied = utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
@ -502,10 +508,7 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None:
if copied < _FRAME_CONT_SIZE: if copied < _FRAME_CONT_SIZE:
frm.data[copied:] = bytearray(_FRAME_CONT_SIZE - copied) frm.data[copied:] = bytearray(_FRAME_CONT_SIZE - copied)
while True: while True:
ret = await loop.race(write, loop.sleep(_CTAP_HID_TIMEOUT_MS)) await write
if ret is not None:
raise TimeoutError
if iface.write(buf) > 0: if iface.write(buf) > 0:
break break
seq += 1 seq += 1

View File

@ -14,9 +14,9 @@ from typing import TYPE_CHECKING
from trezor import io, log from trezor import io, log
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Awaitable, Callable, Coroutine, Generator from typing import Any, Awaitable, Callable, Coroutine, Generator, Union
Task = Coroutine | Generator Task = Union[Coroutine, Generator, "wait"]
AwaitableTask = Task | Awaitable AwaitableTask = Task | Awaitable
Finalizer = Callable[[Task, Any], None] Finalizer = Callable[[Task, Any], None]
@ -202,6 +202,13 @@ class Syscall:
pass pass
class Timeout(Exception):
pass
_TIMEOUT_ERROR = Timeout()
class sleep(Syscall): class sleep(Syscall):
"""Pause current task and resume it after given delay. """Pause current task and resume it after given delay.
@ -233,11 +240,41 @@ class wait(Syscall):
>>> event, x, y = await loop.wait(io.TOUCH) # await touch event >>> event, x, y = await loop.wait(io.TOUCH) # await touch event
""" """
def __init__(self, msg_iface: int) -> None: _DO_NOT_RESCHEDULE = Syscall()
def __init__(self, msg_iface: int, timeout_ms: int | None = None) -> None:
self.msg_iface = msg_iface self.msg_iface = msg_iface
self.timeout_ms = timeout_ms
self.task: Task | None = None
def handle(self, task: Task) -> None: def handle(self, task: Task) -> None:
pause(task, self.msg_iface) self.task = task
pause(self, self.msg_iface)
if self.timeout_ms is not None:
deadline = utime.ticks_add(utime.ticks_ms(), self.timeout_ms)
schedule(self, _TIMEOUT_ERROR, deadline)
def send(self, __value: Any) -> Any:
assert self.task is not None
self.close()
_step(self.task, __value)
return self._DO_NOT_RESCHEDULE
throw = send
def close(self) -> None:
_queue.discard(self)
if self.msg_iface in _paused:
_paused[self.msg_iface].discard(self)
if not _paused[self.msg_iface]:
del _paused[self.msg_iface]
def __iter__(self) -> Generator:
try:
return (yield self)
finally:
# whichever way we got here, we must be removed from the paused list
self.close()
_type_gen: type[Generator] = type((lambda: (yield))()) _type_gen: type[Generator] = type((lambda: (yield))())