diff --git a/core/src/apps/webauthn/fido2.py b/core/src/apps/webauthn/fido2.py index 4ac0a6ce89..056dc56f54 100644 --- a/core/src/apps/webauthn/fido2.py +++ b/core/src/apps/webauthn/fido2.py @@ -374,6 +374,7 @@ async def _read_cmd(iface: HID) -> Cmd | None: desc_cont = frame_cont() read = loop.wait(iface.iface_num() | io.POLL_READ) + # wait for incoming comand indefinitely buf = await read while True: ifrm = overlay_struct(bytearray(buf), desc_init) @@ -409,9 +410,12 @@ async def _read_cmd(iface: HID) -> Cmd | None: else: data = data[:bcnt] + # set a timeout for subsequent reads + read.timeout_ms = _CTAP_HID_TIMEOUT_MS while datalen < bcnt: - buf = await loop.race(read, loop.sleep(_CTAP_HID_TIMEOUT_MS)) - if not isinstance(buf, bytes): + try: + buf = await read + except loop.Timeout: if __debug__: warning(__name__, "_ERR_MSG_TIMEOUT") await send_cmd(cmd_error(ifrm_cid, _ERR_MSG_TIMEOUT), iface) @@ -494,7 +498,9 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None: if offset < datalen: frm = overlay_struct(buf, cont_desc) - write = loop.wait(iface.iface_num() | io.POLL_WRITE) + write = loop.wait( + iface.iface_num() | io.POLL_WRITE, timeout_ms=_CTAP_HID_TIMEOUT_MS + ) while offset < datalen: frm.seq = seq copied = utils.memcpy(frm.data, 0, cmd.data, offset, datalen) @@ -502,10 +508,7 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None: if copied < _FRAME_CONT_SIZE: frm.data[copied:] = bytearray(_FRAME_CONT_SIZE - copied) while True: - ret = await loop.race(write, loop.sleep(_CTAP_HID_TIMEOUT_MS)) - if ret is not None: - raise TimeoutError - + await write if iface.write(buf) > 0: break seq += 1 diff --git a/core/src/trezor/loop.py b/core/src/trezor/loop.py index f7edc6485c..636ee333cd 100644 --- a/core/src/trezor/loop.py +++ b/core/src/trezor/loop.py @@ -14,9 +14,9 @@ from typing import TYPE_CHECKING from trezor import io, log if TYPE_CHECKING: - from typing import Any, Awaitable, Callable, Coroutine, Generator + from typing import Any, Awaitable, Callable, Coroutine, Generator, Union - Task = Coroutine | Generator + Task = Union[Coroutine, Generator, "wait"] AwaitableTask = Task | Awaitable Finalizer = Callable[[Task, Any], None] @@ -202,6 +202,13 @@ class Syscall: pass +class Timeout(Exception): + pass + + +_TIMEOUT_ERROR = Timeout() + + class sleep(Syscall): """Pause current task and resume it after given delay. @@ -233,11 +240,41 @@ class wait(Syscall): >>> event, x, y = await loop.wait(io.TOUCH) # await touch event """ - def __init__(self, msg_iface: int) -> None: + _DO_NOT_RESCHEDULE = Syscall() + + def __init__(self, msg_iface: int, timeout_ms: int | None = None) -> None: self.msg_iface = msg_iface + self.timeout_ms = timeout_ms + self.task: Task | None = None def handle(self, task: Task) -> None: - pause(task, self.msg_iface) + self.task = task + pause(self, self.msg_iface) + if self.timeout_ms is not None: + deadline = utime.ticks_add(utime.ticks_ms(), self.timeout_ms) + schedule(self, _TIMEOUT_ERROR, deadline) + + def send(self, __value: Any) -> Any: + assert self.task is not None + self.close() + _step(self.task, __value) + return self._DO_NOT_RESCHEDULE + + throw = send + + def close(self) -> None: + _queue.discard(self) + if self.msg_iface in _paused: + _paused[self.msg_iface].discard(self) + if not _paused[self.msg_iface]: + del _paused[self.msg_iface] + + def __iter__(self) -> Generator: + try: + return (yield self) + finally: + # whichever way we got here, we must be removed from the paused list + self.close() _type_gen: type[Generator] = type((lambda: (yield))())