From 748d516ac4d96d6e87577dfe877d7232f4403022 Mon Sep 17 00:00:00 2001 From: matejcik Date: Thu, 29 Jun 2023 17:08:13 +0200 Subject: [PATCH] feat(core): introduce timeout to loop.wait() --- core/src/apps/webauthn/fido2.py | 16 ++++++------ core/src/trezor/loop.py | 44 ++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/core/src/apps/webauthn/fido2.py b/core/src/apps/webauthn/fido2.py index 724636ea8..81c4cfdcc 100644 --- a/core/src/apps/webauthn/fido2.py +++ b/core/src/apps/webauthn/fido2.py @@ -371,7 +371,7 @@ class Cmd: async def _read_cmd(iface: HID) -> Cmd | None: desc_init = frame_init() desc_cont = frame_cont() - read = loop.wait(iface.iface_num() | io.POLL_READ) + read = loop.wait(iface.iface_num() | io.POLL_READ, timeout_ms=_CTAP_HID_TIMEOUT_MS) buf = await read while True: @@ -409,8 +409,9 @@ async def _read_cmd(iface: HID) -> Cmd | None: data = data[:bcnt] while datalen < bcnt: - buf = await loop.race(read, loop.sleep(_CTAP_HID_TIMEOUT_MS)) - if not isinstance(buf, bytes): + try: + buf = await read + except loop.Timeout: if __debug__: warning(__name__, "_ERR_MSG_TIMEOUT") await send_cmd(cmd_error(ifrm_cid, _ERR_MSG_TIMEOUT), iface) @@ -493,7 +494,9 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None: if offset < datalen: frm = overlay_struct(buf, cont_desc) - write = loop.wait(iface.iface_num() | io.POLL_WRITE) + write = loop.wait( + iface.iface_num() | io.POLL_WRITE, timeout_ms=_CTAP_HID_TIMEOUT_MS + ) while offset < datalen: frm.seq = seq copied = utils.memcpy(frm.data, 0, cmd.data, offset, datalen) @@ -501,10 +504,7 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None: if copied < _FRAME_CONT_SIZE: frm.data[copied:] = bytearray(_FRAME_CONT_SIZE - copied) while True: - ret = await loop.race(write, loop.sleep(_CTAP_HID_TIMEOUT_MS)) - if ret is not None: - raise TimeoutError - + await write if iface.write(buf) > 0: break seq += 1 diff --git a/core/src/trezor/loop.py b/core/src/trezor/loop.py index 1595a5c44..698987e32 100644 --- a/core/src/trezor/loop.py +++ b/core/src/trezor/loop.py @@ -20,9 +20,10 @@ if TYPE_CHECKING: Callable, Coroutine, Generator, + Union, ) - Task = Coroutine | Generator + Task = Union[Coroutine, Generator, "wait"] AwaitableTask = Task | Awaitable Finalizer = Callable[[Task, Any], None] @@ -205,6 +206,13 @@ class Syscall: pass +class Timeout(Exception): + pass + + +_TIMEOUT_ERROR = Timeout() + + class sleep(Syscall): """Pause current task and resume it after given delay. @@ -236,11 +244,41 @@ class wait(Syscall): >>> event, x, y = await loop.wait(io.TOUCH) # await touch event """ - def __init__(self, msg_iface: int) -> None: + _DO_NOT_RESCHEDULE = Syscall() + + def __init__(self, msg_iface: int, timeout_ms: int | None = None) -> None: self.msg_iface = msg_iface + self.timeout_ms = timeout_ms + self.task: Task | None = None def handle(self, task: Task) -> None: - pause(task, self.msg_iface) + self.task = task + pause(self, self.msg_iface) + if self.timeout_ms is not None: + deadline = utime.ticks_add(utime.ticks_ms(), self.timeout_ms) + schedule(self, _TIMEOUT_ERROR, deadline) + + def send(self, __value: Any) -> Any: + assert self.task is not None + _paused[self.msg_iface].discard(self) + _queue.discard(self) + _step(self.task, __value) + return self._DO_NOT_RESCHEDULE + + throw = send + + def close(self) -> None: + pass + + def __iter__(self) -> Generator: + try: + return (yield self) + except BaseException: + # exception was raised on the waiting task externally with + # close() or throw(), kill the children tasks and re-raise + _queue.discard(self) + _paused[self.msg_iface].discard(self) + raise _type_gen: type[Generator] = type((lambda: (yield))())