diff --git a/src/trezor/loop.py b/src/trezor/loop.py index c061283e07..ad94eded1c 100644 --- a/src/trezor/loop.py +++ b/src/trezor/loop.py @@ -6,62 +6,119 @@ from . import msg from . import log if __debug__: - # For performance stats + # for performance stats import array log_delay_pos = 0 log_delay_rb_len = const(10) log_delay_rb = array.array('i', [0] * log_delay_rb_len) -# Touch interface -TOUCH = const(256) # 0-255 is reserved for USB interfaces -TOUCH_START = const(1) -TOUCH_MOVE = const(2) -TOUCH_END = const(4) +paused_tasks = {} # {message interface: [task]} +schedule_counter = 0 +scheduled_tasks = [] # heap: [(time, counter, task, value)] +MAX_SELECT_DELAY = const(1000000) -msg_handlers = {} # Message interface -> [generator] -time_queue = [] -time_ticket = 0 +# message interfaces: +# 0-255 - USB HID +# 256 - touch event interface + +TOUCH = const(256) # interface +TOUCH_START = const(1) # event +TOUCH_MOVE = const(2) # event +TOUCH_END = const(4) # event -def schedule(gen, data=None, time=None): - global time_ticket - if not time: +def schedule_task(task, value=None, time=None): + global schedule_counter + if time is None: time = utime.ticks_us() - heappush(time_queue, (time, time_ticket, gen, data)) - time_ticket += 1 - return gen + heappush(scheduled_tasks, (time, schedule_counter, task, value)) + schedule_counter += 1 -def unschedule(gen): - global time_queue - time_queue = [entry for entry in time_queue if entry[1] is not gen] - heapify(time_queue) +def unschedule_task(task): + global scheduled_tasks + scheduled_tasks = [t for t in scheduled_tasks if t[1] is not task] + heapify(scheduled_tasks) -def block(gen, iface): - if iface in msg_handlers: - msg_handlers[iface].append(gen) +def pause_task(task, iface): + paused_tasks.setdefault(iface, []).append(task) + + +def unpause_task(task): + for iface in paused_tasks: + if task in paused_tasks[iface]: + paused_tasks[iface].remove(task) + + +def run_task(task, value): + try: + if isinstance(value, Exception): + result = task.throw(value) + else: + result = task.send(value) + except StopIteration as e: + log.debug(__name__, '%s finished', task) + except Exception as e: + log.exception(__name__, e) else: - msg_handlers[iface] = [gen] + if isinstance(result, Syscall): + result.handle(task) + elif result is None: + schedule_task(task) + else: + log.error(__name__, '%s is unknown syscall', result) -def unblock(gen): - for iface in msg_handlers: - if gen in msg_handlers[iface]: - msg_handlers[iface].remove(gen) +def handle_message(message): + if not paused_tasks: + return + iface, *value = message + tasks = paused_tasks.pop(iface, ()) + for task in tasks: + run_task(task, value) + + +def handle_timeout(): + if not scheduled_tasks: + return + _, _, task, value = heappop(scheduled_tasks) + run_task(task, value) + + +def run_forever(): + if __debug__: + global log_delay_pos + while True: + if scheduled_tasks: + t, _, _, _ = scheduled_tasks[0] + delay = t - utime.ticks_us() + else: + delay = MAX_SELECT_DELAY + if __debug__: + # add current delay to ring buffer for performance stats + log_delay_rb[log_delay_pos] = delay + log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len + message = msg.select(delay) + if message: + handle_message(message) + else: + handle_timeout() class Syscall(): - pass + + def __iter__(self): + return (yield self) class Sleep(Syscall): - def __init__(self, us): - self.time = utime.ticks_us() + us + def __init__(self, delay_us): + self.time = delay_us + utime.ticks_us() - def register(self, gen): - schedule(gen, self, self.time) + def handle(self, task): + schedule_task(task, self, self.time) class Select(Syscall): @@ -69,105 +126,71 @@ class Select(Syscall): def __init__(self, iface): self.iface = iface - def register(self, gen): - block(gen, self.iface) + def handle(self, task): + pause_task(task, self.iface) + + +NO_VALUE = () + + +class Future(Syscall): + + def __init__(self): + self.value = NO_VALUE + self.task = None + + def handle(self, task): + self.task = task + if self.value is not NO_VALUE: + self._deliver() + + def resolve(self, value): + if self.value is NO_VALUE: + self.value = value + if self.task is not None: + self._deliver() + + def _deliver(self): + schedule_task(self.task, self.value) class Wait(Syscall): - def __init__(self, gens, wait_for=1, exit_others=True): - self.gens = gens + def __init__(self, children, wait_for=1, exit_others=True): + self.children = children self.wait_for = wait_for self.exit_others = exit_others - self.scheduled = [] # In uPython, set() cannot contain generators + self.scheduled = [] self.finished = [] self.callback = None - def register(self, gen): - self.scheduled = [schedule(self._wait(g)) for g in self.gens] - self.callback = gen + def handle(self, task): + self.callback = task + self.scheduled = [self._wait(c) for c in self.children] + for ct in self.scheduled: + schedule_task(ct) def exit(self): - for gen in self.scheduled: - if gen not in self.finished and isinstance(gen, type_gen): - unschedule(gen) - unblock(gen) - gen.close() + for task in self.scheduled: + if task not in self.finished: + unschedule_task(task) + unpause_task(task) + task.close() - def _wait(self, gen): + def _wait(self, child): try: - if isinstance(gen, type_gen): - result = yield from gen + if isinstance(child, type_gen): + result = yield from child else: - result = yield gen - except Exception as exc: - self._finish(gen, exc) + result = yield child + except Exception as e: + self._finish(child, e) else: - self._finish(gen, result) + self._finish(child, result) - def _finish(self, gen, result): - self.finished.append(gen) + def _finish(self, child, result): + self.finished.append(child) if self.wait_for == len(self.finished) or isinstance(result, Exception): if self.exit_others: self.exit() - schedule(self.callback, result) - self.callback = None - - -def step_task(gen, data): - if isinstance(data, Exception): - result = gen.throw(data) - else: - result = gen.send(data) - if isinstance(result, Syscall): - result.register(gen) # Execute the syscall - elif result is None: - schedule(gen) # Just call us asap - else: - raise Exception('Unhandled result %s by %s' % (result, gen)) - - -def run_forever(): - if __debug__: - global log_delay_pos, log_delay_rb, log_delay_rb_len - - DELAY_MAX = const(1000000) - - while True: - - # Peek at how long we can sleep while waiting for an event - if time_queue: - t, _, _, _ = time_queue[0] - delay = t - utime.ticks_us() - else: - delay = DELAY_MAX - - if __debug__: - # Adding current delay to ring buffer for performance stats - log_delay_rb[log_delay_pos] = delay - log_delay_pos = (log_delay_pos + 1) % log_delay_rb_len - - m = msg.select(delay) - if m: - # Run interrupt handlers right away, they have priority - iface, *data = m - tasks = msg_handlers.pop(iface, None) - if not tasks: - log.info(__name__, 'No handler for message: %s', iface) - continue - else: - # Run something from the time queue - if time_queue: - _, _, gen, data = heappop(time_queue) - tasks = (gen,) - else: - continue - - # Run the tasks - for gen in tasks: - try: - step_task(gen, data) - except StopIteration as e: - log.debug(__name__, '%s finished', gen) - except Exception as e: - log.exception(__name__, e) + schedule_task(self.callback, result) diff --git a/src/trezor/main.py b/src/trezor/main.py index bcb70bf653..e9ac489b3b 100644 --- a/src/trezor/main.py +++ b/src/trezor/main.py @@ -12,7 +12,7 @@ log.level = log.INFO def perf_info_debug(): while True: - queue = [str(x[2]).split("'")[1] for x in loop.time_queue] + queue = [str(x[2]).split("'")[1] for x in loop.scheduled_tasks] delay_avg = sum(loop.log_delay_rb) / loop.log_delay_rb_len delay_last = loop.log_delay_rb[loop.log_delay_pos] @@ -29,13 +29,13 @@ def perf_info(): while True: gc.collect() log.info(__name__, "mem_alloc: %d", gc.mem_alloc()) - yield loop.sleep(1000000) + yield loop.Sleep(1000000) def run(main_layout): if __debug__: - loop.schedule(perf_info_debug()) + loop.schedule_task(perf_info_debug()) else: - loop.schedule(perf_info()) - loop.schedule(layout.set_main(main_layout)) + loop.schedule_task(perf_info()) + loop.schedule_task(layout.set_main(main_layout)) loop.run_forever()