mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-02 04:18:20 +00:00
feat(core): introduce timeout to loop.wait()
This commit is contained in:
parent
0f84d51051
commit
5686f46a03
@ -374,6 +374,7 @@ async def _read_cmd(iface: HID) -> Cmd | None:
|
|||||||
desc_cont = frame_cont()
|
desc_cont = frame_cont()
|
||||||
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
read = loop.wait(iface.iface_num() | io.POLL_READ)
|
||||||
|
|
||||||
|
# wait for incoming comand indefinitely
|
||||||
buf = await read
|
buf = await read
|
||||||
while True:
|
while True:
|
||||||
ifrm = overlay_struct(bytearray(buf), desc_init)
|
ifrm = overlay_struct(bytearray(buf), desc_init)
|
||||||
@ -409,9 +410,12 @@ async def _read_cmd(iface: HID) -> Cmd | None:
|
|||||||
else:
|
else:
|
||||||
data = data[:bcnt]
|
data = data[:bcnt]
|
||||||
|
|
||||||
|
# set a timeout for subsequent reads
|
||||||
|
read.timeout_ms = _CTAP_HID_TIMEOUT_MS
|
||||||
while datalen < bcnt:
|
while datalen < bcnt:
|
||||||
buf = await loop.race(read, loop.sleep(_CTAP_HID_TIMEOUT_MS))
|
try:
|
||||||
if not isinstance(buf, bytes):
|
buf = await read
|
||||||
|
except loop.Timeout:
|
||||||
if __debug__:
|
if __debug__:
|
||||||
warning(__name__, "_ERR_MSG_TIMEOUT")
|
warning(__name__, "_ERR_MSG_TIMEOUT")
|
||||||
await send_cmd(cmd_error(ifrm_cid, _ERR_MSG_TIMEOUT), iface)
|
await send_cmd(cmd_error(ifrm_cid, _ERR_MSG_TIMEOUT), iface)
|
||||||
@ -494,7 +498,9 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None:
|
|||||||
if offset < datalen:
|
if offset < datalen:
|
||||||
frm = overlay_struct(buf, cont_desc)
|
frm = overlay_struct(buf, cont_desc)
|
||||||
|
|
||||||
write = loop.wait(iface.iface_num() | io.POLL_WRITE)
|
write = loop.wait(
|
||||||
|
iface.iface_num() | io.POLL_WRITE, timeout_ms=_CTAP_HID_TIMEOUT_MS
|
||||||
|
)
|
||||||
while offset < datalen:
|
while offset < datalen:
|
||||||
frm.seq = seq
|
frm.seq = seq
|
||||||
copied = utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
|
copied = utils.memcpy(frm.data, 0, cmd.data, offset, datalen)
|
||||||
@ -502,10 +508,7 @@ async def send_cmd(cmd: Cmd, iface: HID) -> None:
|
|||||||
if copied < _FRAME_CONT_SIZE:
|
if copied < _FRAME_CONT_SIZE:
|
||||||
frm.data[copied:] = bytearray(_FRAME_CONT_SIZE - copied)
|
frm.data[copied:] = bytearray(_FRAME_CONT_SIZE - copied)
|
||||||
while True:
|
while True:
|
||||||
ret = await loop.race(write, loop.sleep(_CTAP_HID_TIMEOUT_MS))
|
await write
|
||||||
if ret is not None:
|
|
||||||
raise TimeoutError
|
|
||||||
|
|
||||||
if iface.write(buf) > 0:
|
if iface.write(buf) > 0:
|
||||||
break
|
break
|
||||||
seq += 1
|
seq += 1
|
||||||
|
@ -14,9 +14,9 @@ from typing import TYPE_CHECKING
|
|||||||
from trezor import io, log
|
from trezor import io, log
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any, Awaitable, Callable, Coroutine, Generator
|
from typing import Any, Awaitable, Callable, Coroutine, Generator, Union
|
||||||
|
|
||||||
Task = Coroutine | Generator
|
Task = Union[Coroutine, Generator, "wait"]
|
||||||
AwaitableTask = Task | Awaitable
|
AwaitableTask = Task | Awaitable
|
||||||
Finalizer = Callable[[Task, Any], None]
|
Finalizer = Callable[[Task, Any], None]
|
||||||
|
|
||||||
@ -202,6 +202,13 @@ class Syscall:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Timeout(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_TIMEOUT_ERROR = Timeout()
|
||||||
|
|
||||||
|
|
||||||
class sleep(Syscall):
|
class sleep(Syscall):
|
||||||
"""Pause current task and resume it after given delay.
|
"""Pause current task and resume it after given delay.
|
||||||
|
|
||||||
@ -233,11 +240,41 @@ class wait(Syscall):
|
|||||||
>>> event, x, y = await loop.wait(io.TOUCH) # await touch event
|
>>> event, x, y = await loop.wait(io.TOUCH) # await touch event
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, msg_iface: int) -> None:
|
_DO_NOT_RESCHEDULE = Syscall()
|
||||||
|
|
||||||
|
def __init__(self, msg_iface: int, timeout_ms: int | None = None) -> None:
|
||||||
self.msg_iface = msg_iface
|
self.msg_iface = msg_iface
|
||||||
|
self.timeout_ms = timeout_ms
|
||||||
|
self.task: Task | None = None
|
||||||
|
|
||||||
def handle(self, task: Task) -> None:
|
def handle(self, task: Task) -> None:
|
||||||
pause(task, self.msg_iface)
|
self.task = task
|
||||||
|
pause(self, self.msg_iface)
|
||||||
|
if self.timeout_ms is not None:
|
||||||
|
deadline = utime.ticks_add(utime.ticks_ms(), self.timeout_ms)
|
||||||
|
schedule(self, _TIMEOUT_ERROR, deadline)
|
||||||
|
|
||||||
|
def send(self, __value: Any) -> Any:
|
||||||
|
assert self.task is not None
|
||||||
|
self.close()
|
||||||
|
_step(self.task, __value)
|
||||||
|
return self._DO_NOT_RESCHEDULE
|
||||||
|
|
||||||
|
throw = send
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
_queue.discard(self)
|
||||||
|
if self.msg_iface in _paused:
|
||||||
|
_paused[self.msg_iface].discard(self)
|
||||||
|
if not _paused[self.msg_iface]:
|
||||||
|
del _paused[self.msg_iface]
|
||||||
|
|
||||||
|
def __iter__(self) -> Generator:
|
||||||
|
try:
|
||||||
|
return (yield self)
|
||||||
|
finally:
|
||||||
|
# whichever way we got here, we must be removed from the paused list
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
_type_gen: type[Generator] = type((lambda: (yield))())
|
_type_gen: type[Generator] = type((lambda: (yield))())
|
||||||
|
Loading…
Reference in New Issue
Block a user