1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-11 07:50:57 +00:00

core/loop: properly cleanup task waiting on a chan

This commit is contained in:
Jan Pochyla 2019-08-22 17:29:21 +02:00
parent 29cca614f0
commit 2c8b90f86e
9 changed files with 32 additions and 25 deletions

View File

@ -42,7 +42,7 @@ async def confirm(
) )
if __debug__: if __debug__:
return await ctx.wait(dialog, confirm_signal) is CONFIRMED return await ctx.wait(dialog, confirm_signal()) is CONFIRMED
else: else:
return await ctx.wait(dialog) is CONFIRMED return await ctx.wait(dialog) is CONFIRMED
@ -66,7 +66,7 @@ async def hold_to_confirm(
dialog = HoldToConfirm(content, confirm, confirm_style, loader_style) dialog = HoldToConfirm(content, confirm, confirm_style, loader_style)
if __debug__: if __debug__:
return await ctx.wait(dialog, confirm_signal) is CONFIRMED return await ctx.wait(dialog, confirm_signal()) is CONFIRMED
else: else:
return await ctx.wait(dialog) is CONFIRMED return await ctx.wait(dialog) is CONFIRMED

View File

@ -65,7 +65,7 @@ async def request_passphrase_ack(ctx: wire.Context, on_device: bool) -> str:
keyboard = PassphraseKeyboard("Enter passphrase", _MAX_PASSPHRASE_LEN) keyboard = PassphraseKeyboard("Enter passphrase", _MAX_PASSPHRASE_LEN)
if __debug__: if __debug__:
passphrase = await ctx.wait(keyboard, input_signal) passphrase = await ctx.wait(keyboard, input_signal())
else: else:
passphrase = await ctx.wait(keyboard) passphrase = await ctx.wait(keyboard)
if passphrase is CANCELLED: if passphrase is CANCELLED:

View File

@ -25,7 +25,7 @@ async def request_pin(
while True: while True:
if __debug__: if __debug__:
result = await loop.race(dialog, input_signal) result = await loop.race(dialog, input_signal())
else: else:
result = await dialog result = await dialog
if result is CANCELLED: if result is CANCELLED:

View File

@ -102,7 +102,7 @@ async def _confirm_word(ctx, share_index, numbered_share_words, count):
choices = [word for _, word in numbered_choices] choices = [word for _, word in numbered_choices]
select = MnemonicWordSelect(choices, share_index, checked_index, count) select = MnemonicWordSelect(choices, share_index, checked_index, count)
if __debug__: if __debug__:
selected_word = await ctx.wait(select, debug.input_signal) selected_word = await ctx.wait(select, debug.input_signal())
else: else:
selected_word = await ctx.wait(select) selected_word = await ctx.wait(select)

View File

@ -43,7 +43,7 @@ async def request_word_count(ctx: wire.Context, dry_run: bool) -> int:
text.normal("Number of words?") text.normal("Number of words?")
if __debug__: if __debug__:
count = await ctx.wait(WordSelector(text), input_signal) count = await ctx.wait(WordSelector(text), input_signal())
count = int(count) # if input_signal was triggered, count is a string count = int(count) # if input_signal was triggered, count is a string
else: else:
count = await ctx.wait(WordSelector(text)) count = await ctx.wait(WordSelector(text))
@ -63,7 +63,7 @@ async def request_mnemonic(
else: else:
keyboard = Bip39Keyboard("Type word %s of %s:" % (i + 1, count)) keyboard = Bip39Keyboard("Type word %s of %s:" % (i + 1, count))
if __debug__: if __debug__:
word = await ctx.wait(keyboard, input_signal) word = await ctx.wait(keyboard, input_signal())
else: else:
word = await ctx.wait(keyboard) word = await ctx.wait(keyboard)
@ -145,7 +145,7 @@ async def show_keyboard_info(ctx: wire.Context) -> None:
"Great!", "Great!",
) )
if __debug__: if __debug__:
await ctx.wait(info, confirm_signal) await ctx.wait(info, confirm_signal())
else: else:
await ctx.wait(info) await ctx.wait(info)

View File

@ -30,7 +30,7 @@ async def naive_pagination(
while True: while True:
await ctx.call(ButtonRequest(code=ButtonRequestType.SignTx), ButtonAck) await ctx.call(ButtonRequest(code=ButtonRequestType.SignTx), ButtonAck)
if __debug__: if __debug__:
result = await loop.race(paginated, confirm_signal) result = await loop.race(paginated, confirm_signal())
else: else:
result = await paginated result = await paginated
if result is CONFIRMED: if result is CONFIRMED:

View File

@ -263,17 +263,14 @@ class race(Syscall):
>>> # async def wait_for_touch(): ... >>> # async def wait_for_touch(): ...
>>> # async def animate_logo(): ... >>> # async def animate_logo(): ...
>>> some_signal = loop.signal()
>>> touch_task = wait_for_touch() >>> touch_task = wait_for_touch()
>>> animation_task = animate_logo() >>> animation_task = animate_logo()
>>> racer = loop.race(some_signal, touch_task, animation_task) >>> racer = loop.race(touch_task, animation_task)
>>> result = await racer >>> result = await racer
>>> if animation_task in racer.finished: >>> if animation_task in racer.finished:
>>> print('animation task returned value:', result) >>> print('animation task returned value:', result)
>>> elif touch_task in racer.finished: >>> elif touch_task in racer.finished:
>>> print('touch task returned value:', result) >>> print('touch task returned value:', result)
>>> else:
>>> print('signal was triggered with value:', result)
Note: You should not directly `yield` a `race` instance, see logic in Note: You should not directly `yield` a `race` instance, see logic in
`race.__iter__` for explanation. Always use `await`. `race.__iter__` for explanation. Always use `await`.
@ -364,32 +361,42 @@ class chan:
""" """
class Put(Syscall): class Put(Syscall):
def __init__(self, ch: "chan") -> None: def __init__(self, ch: "chan", value: Any) -> None:
self.ch = ch self.ch = ch
self.value = None # type: Any
def __call__(self, value: Any) -> Syscall:
self.value = value self.value = value
return self self.task = None # type: Optional[Task]
def handle(self, task: Task) -> None: def handle(self, task: Task) -> None:
self.task = task
self.ch._schedule_put(task, self.value) self.ch._schedule_put(task, self.value)
class Take(Syscall): class Take(Syscall):
def __init__(self, ch: "chan") -> None: def __init__(self, ch: "chan") -> None:
self.ch = ch self.ch = ch
self.task = None # type: Optional[Task]
def __call__(self) -> Syscall:
return self
def handle(self, task) -> None: def handle(self, task) -> None:
self.task = task
self.ch._schedule_take(task) self.ch._schedule_take(task)
def __init__(self): def __init__(self):
self.putters = [] # type: List[Tuple[Optional[Task], Any]] self.putters = [] # type: List[Tuple[Optional[Task], Any]]
self.takers = [] # type: List[Task] self.takers = [] # type: List[Task]
self.put = chan.Put(self)
self.take = chan.Take(self) def put(self, value: Any) -> None:
put = chan.Put(self, value)
try:
yield put
except: # noqa: E722
self.putters.remove((put.task, value))
def take(self) -> None:
take = chan.Take(self)
try:
yield take
except: # noqa: E722
self.takers.remove(take.task)
raise
def publish(self, value: Any) -> None: def publish(self, value: Any) -> None:
if self.takers: if self.takers:

View File

@ -267,7 +267,7 @@ class Layout(Component):
# layout tasks to trigger restart by exiting (new tasks are created # layout tasks to trigger restart by exiting (new tasks are created
# and we continue, because we are in a loop). # and we continue, because we are in a loop).
while True: while True:
await loop.race(layout_chan.take, *self.create_tasks()) await loop.race(layout_chan.take(), *self.create_tasks())
except Result as result: except Result as result:
# Result exception was raised, this means this layout is complete. # Result exception was raised, this means this layout is complete.
value = result.value value = result.value

View File

@ -77,7 +77,7 @@ class Paginated(ui.Layout):
directions = SWIPE_VERTICAL directions = SWIPE_VERTICAL
if __debug__: if __debug__:
swipe = await loop.race(Swipe(directions), swipe_signal) swipe = await loop.race(Swipe(directions), swipe_signal())
else: else:
swipe = await Swipe(directions) swipe = await Swipe(directions)