diff --git a/src/trezor/loop.py b/src/trezor/loop.py index c69f39aec1..e6c9939cba 100644 --- a/src/trezor/loop.py +++ b/src/trezor/loop.py @@ -211,7 +211,7 @@ class signal(Syscall): class spawn(Syscall): ''' - Execute one or more children tasks and wait until one or more of them exit. + Execute one or more children tasks and wait until one of them exits. Return value of `spawn` is the return value of task that triggered the completion. By default, `spawn` returns after the first child completes, and other running children are killed (by cancelling any pending schedules and @@ -234,9 +234,8 @@ class spawn(Syscall): `spawn.__iter__` for explanation. Always use `await`. ''' - def __init__(self, *children, wait_for=1, exit_others=True): + def __init__(self, *children, exit_others=True): self.children = children - self.wait_for = wait_for self.exit_others = exit_others self.scheduled = None # list of scheduled wrapper tasks self.finished = None # list of children that finished @@ -245,28 +244,31 @@ class spawn(Syscall): def handle(self, task): self.callback = task self.finished = [] - self.scheduled = [self._wait(c) for c in self.children] - for ct in self.scheduled: - schedule(ct) + self.scheduled = [] + for index, child in enumerate(self.children): + parent = self._wait(child, index) + schedule(parent) + self.scheduled.append(parent) - def exit(self): - for ct in self.scheduled: - close(ct) + def exit(self, skip_index=-1): + for index, parent in enumerate(self.scheduled): + if index != skip_index: + close(parent) - async def _wait(self, child): + async def _wait(self, child, index): try: result = await child except Exception as e: - self._finish(child, e) + self._finish(child, index, e) else: - self._finish(child, result) + self._finish(child, index, result) - def _finish(self, child, result): - self.finished.append(child) - if self.wait_for == len(self.finished) or isinstance(result, Exception): - schedule(self.callback, result) + def _finish(self, child, index, result): + if not self.finished: + self.finished.append(child) if self.exit_others: - self.exit() + self.exit(index) + schedule(self.callback, result) def __iter__(self): try: