mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-23 06:48:16 +00:00
core/loop: properly cleanup task waiting on a chan
This commit is contained in:
parent
29cca614f0
commit
2c8b90f86e
@ -42,7 +42,7 @@ async def confirm(
|
||||
)
|
||||
|
||||
if __debug__:
|
||||
return await ctx.wait(dialog, confirm_signal) is CONFIRMED
|
||||
return await ctx.wait(dialog, confirm_signal()) is CONFIRMED
|
||||
else:
|
||||
return await ctx.wait(dialog) is CONFIRMED
|
||||
|
||||
@ -66,7 +66,7 @@ async def hold_to_confirm(
|
||||
dialog = HoldToConfirm(content, confirm, confirm_style, loader_style)
|
||||
|
||||
if __debug__:
|
||||
return await ctx.wait(dialog, confirm_signal) is CONFIRMED
|
||||
return await ctx.wait(dialog, confirm_signal()) is CONFIRMED
|
||||
else:
|
||||
return await ctx.wait(dialog) is CONFIRMED
|
||||
|
||||
|
@ -65,7 +65,7 @@ async def request_passphrase_ack(ctx: wire.Context, on_device: bool) -> str:
|
||||
|
||||
keyboard = PassphraseKeyboard("Enter passphrase", _MAX_PASSPHRASE_LEN)
|
||||
if __debug__:
|
||||
passphrase = await ctx.wait(keyboard, input_signal)
|
||||
passphrase = await ctx.wait(keyboard, input_signal())
|
||||
else:
|
||||
passphrase = await ctx.wait(keyboard)
|
||||
if passphrase is CANCELLED:
|
||||
|
@ -25,7 +25,7 @@ async def request_pin(
|
||||
|
||||
while True:
|
||||
if __debug__:
|
||||
result = await loop.race(dialog, input_signal)
|
||||
result = await loop.race(dialog, input_signal())
|
||||
else:
|
||||
result = await dialog
|
||||
if result is CANCELLED:
|
||||
|
@ -102,7 +102,7 @@ async def _confirm_word(ctx, share_index, numbered_share_words, count):
|
||||
choices = [word for _, word in numbered_choices]
|
||||
select = MnemonicWordSelect(choices, share_index, checked_index, count)
|
||||
if __debug__:
|
||||
selected_word = await ctx.wait(select, debug.input_signal)
|
||||
selected_word = await ctx.wait(select, debug.input_signal())
|
||||
else:
|
||||
selected_word = await ctx.wait(select)
|
||||
|
||||
|
@ -43,7 +43,7 @@ async def request_word_count(ctx: wire.Context, dry_run: bool) -> int:
|
||||
text.normal("Number of words?")
|
||||
|
||||
if __debug__:
|
||||
count = await ctx.wait(WordSelector(text), input_signal)
|
||||
count = await ctx.wait(WordSelector(text), input_signal())
|
||||
count = int(count) # if input_signal was triggered, count is a string
|
||||
else:
|
||||
count = await ctx.wait(WordSelector(text))
|
||||
@ -63,7 +63,7 @@ async def request_mnemonic(
|
||||
else:
|
||||
keyboard = Bip39Keyboard("Type word %s of %s:" % (i + 1, count))
|
||||
if __debug__:
|
||||
word = await ctx.wait(keyboard, input_signal)
|
||||
word = await ctx.wait(keyboard, input_signal())
|
||||
else:
|
||||
word = await ctx.wait(keyboard)
|
||||
|
||||
@ -145,7 +145,7 @@ async def show_keyboard_info(ctx: wire.Context) -> None:
|
||||
"Great!",
|
||||
)
|
||||
if __debug__:
|
||||
await ctx.wait(info, confirm_signal)
|
||||
await ctx.wait(info, confirm_signal())
|
||||
else:
|
||||
await ctx.wait(info)
|
||||
|
||||
|
@ -30,7 +30,7 @@ async def naive_pagination(
|
||||
while True:
|
||||
await ctx.call(ButtonRequest(code=ButtonRequestType.SignTx), ButtonAck)
|
||||
if __debug__:
|
||||
result = await loop.race(paginated, confirm_signal)
|
||||
result = await loop.race(paginated, confirm_signal())
|
||||
else:
|
||||
result = await paginated
|
||||
if result is CONFIRMED:
|
||||
|
@ -263,17 +263,14 @@ class race(Syscall):
|
||||
|
||||
>>> # async def wait_for_touch(): ...
|
||||
>>> # async def animate_logo(): ...
|
||||
>>> some_signal = loop.signal()
|
||||
>>> touch_task = wait_for_touch()
|
||||
>>> animation_task = animate_logo()
|
||||
>>> racer = loop.race(some_signal, touch_task, animation_task)
|
||||
>>> racer = loop.race(touch_task, animation_task)
|
||||
>>> result = await racer
|
||||
>>> if animation_task in racer.finished:
|
||||
>>> print('animation task returned value:', result)
|
||||
>>> elif touch_task in racer.finished:
|
||||
>>> print('touch task returned value:', result)
|
||||
>>> else:
|
||||
>>> print('signal was triggered with value:', result)
|
||||
|
||||
Note: You should not directly `yield` a `race` instance, see logic in
|
||||
`race.__iter__` for explanation. Always use `await`.
|
||||
@ -364,32 +361,42 @@ class chan:
|
||||
"""
|
||||
|
||||
class Put(Syscall):
|
||||
def __init__(self, ch: "chan") -> None:
|
||||
def __init__(self, ch: "chan", value: Any) -> None:
|
||||
self.ch = ch
|
||||
self.value = None # type: Any
|
||||
|
||||
def __call__(self, value: Any) -> Syscall:
|
||||
self.value = value
|
||||
return self
|
||||
self.task = None # type: Optional[Task]
|
||||
|
||||
def handle(self, task: Task) -> None:
|
||||
self.task = task
|
||||
self.ch._schedule_put(task, self.value)
|
||||
|
||||
class Take(Syscall):
|
||||
def __init__(self, ch: "chan") -> None:
|
||||
self.ch = ch
|
||||
|
||||
def __call__(self) -> Syscall:
|
||||
return self
|
||||
self.task = None # type: Optional[Task]
|
||||
|
||||
def handle(self, task) -> None:
|
||||
self.task = task
|
||||
self.ch._schedule_take(task)
|
||||
|
||||
def __init__(self):
|
||||
self.putters = [] # type: List[Tuple[Optional[Task], Any]]
|
||||
self.takers = [] # type: List[Task]
|
||||
self.put = chan.Put(self)
|
||||
self.take = chan.Take(self)
|
||||
|
||||
def put(self, value: Any) -> None:
|
||||
put = chan.Put(self, value)
|
||||
try:
|
||||
yield put
|
||||
except: # noqa: E722
|
||||
self.putters.remove((put.task, value))
|
||||
|
||||
def take(self) -> None:
|
||||
take = chan.Take(self)
|
||||
try:
|
||||
yield take
|
||||
except: # noqa: E722
|
||||
self.takers.remove(take.task)
|
||||
raise
|
||||
|
||||
def publish(self, value: Any) -> None:
|
||||
if self.takers:
|
||||
|
@ -267,7 +267,7 @@ class Layout(Component):
|
||||
# layout tasks to trigger restart by exiting (new tasks are created
|
||||
# and we continue, because we are in a loop).
|
||||
while True:
|
||||
await loop.race(layout_chan.take, *self.create_tasks())
|
||||
await loop.race(layout_chan.take(), *self.create_tasks())
|
||||
except Result as result:
|
||||
# Result exception was raised, this means this layout is complete.
|
||||
value = result.value
|
||||
|
@ -77,7 +77,7 @@ class Paginated(ui.Layout):
|
||||
directions = SWIPE_VERTICAL
|
||||
|
||||
if __debug__:
|
||||
swipe = await loop.race(Swipe(directions), swipe_signal)
|
||||
swipe = await loop.race(Swipe(directions), swipe_signal())
|
||||
else:
|
||||
swipe = await Swipe(directions)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user