diff --git a/src/apps/homescreen/__init__.py b/src/apps/homescreen/__init__.py index ea11a22c7..08af2c8aa 100644 --- a/src/apps/homescreen/__init__.py +++ b/src/apps/homescreen/__init__.py @@ -1,17 +1,29 @@ -from trezor.dispatcher import register -from trezor.utils import unimport_func +from trezor.wire import register_type, protobuf_handler, write_message +from trezor.utils import unimport +from trezor.messages.wire_types import Initialize -@unimport_func -def dispatch_Initialize(mtype, mbuf): - from trezor.messages.Initialize import Initialize - - message = Initialize.loads(mbuf) - - from .layout_homescreen import layout_homescreen - return layout_homescreen(message) +@unimport +async def dispatch_Initialize(_, session_id): + from trezor.messages.Features import Features + features = Features( + revision='deadbeef', + bootloader_hash='deadbeef', + device_id='DEADBEEF', + coins=[], + imported=False, + initialized=False, + label='My TREZOR', + major_version=2, + minor_version=0, + patch_version=0, + pin_cached=False, + pin_protection=True, + passphrase_cached=False, + passphrase_protection=False, + vendor='bitcointrezor.com') + await write_message(session_id, features) def boot(): - Initialize = 0 - register(Initialize, dispatch_Initialize) + register_type(Initialize, protobuf_handler, dispatch_Initialize) diff --git a/src/apps/homescreen/layout_homescreen.py b/src/apps/homescreen/layout_homescreen.py index 624e9fdb6..8a064a38b 100644 --- a/src/apps/homescreen/layout_homescreen.py +++ b/src/apps/homescreen/layout_homescreen.py @@ -1,6 +1,5 @@ -from trezor import ui, dispatcher, loop, res, wire +from trezor import ui, loop, res from trezor.ui.swipe import Swipe -from trezor.utils import unimport_gen async def swipe_to_rotate(): @@ -10,37 +9,14 @@ async def swipe_to_rotate(): async def animate_logo(): - # def func(foreground): - # ui.display.icon(0, 0, res.load( - # 'apps/homescreen/res/trezor.toig'), foreground, ui.BLACK) - # await ui.animate_pulse(func, ui.WHITE, ui.GREY, speed=400000) - + icon = res.load('apps/homescreen/res/trezor.toig') async for fg in ui.pulse_animation(ui.WHITE, ui.GREY, speed=400000): - icon = res.load('apps/homescreen/res/trezor.toig') ui.display.icon(0, 0, icon, fg, ui.BLACK) -@unimport_gen -async def layout_homescreen(initialize_msg=None): - if initialize_msg is not None: - from trezor.messages.Features import Features - features = Features() - features.revision = 'deadbeef' - features.bootloader_hash = 'deadbeef' - features.device_id = 'DEADBEEF' - features.coins = [] - features.imported = False - features.initialized = False - features.label = 'My TREZOR' - features.major_version = 2 - features.minor_version = 0 - features.patch_version = 0 - features.pin_cached = False - features.pin_protection = True - features.passphrase_cached = False - features.passphrase_protection = False - features.vendor = 'bitcointrezor.com' - await wire.write(features) - await loop.Wait([dispatcher.dispatch(), - swipe_to_rotate(), - animate_logo()]) +async def layout_homescreen(): + wait = loop.Wait([swipe_to_rotate(), animate_logo()]) + try: + await wait + finally: + wait.exit() diff --git a/src/main.py b/src/main.py index 83bf0bdd0..8fc0129b4 100644 --- a/src/main.py +++ b/src/main.py @@ -1,5 +1,7 @@ import trezor.main from trezor import msg +from trezor import ui +from trezor import wire # Load all applications from apps import playground @@ -14,13 +16,16 @@ management.boot() wallet.boot() # Change backlight to white for better visibility -trezor.ui.display.backlight(255) +ui.display.backlight(255) # Just a demo to show how to register USB ifaces msg.setup([(1, 0xF53C), (2, 0xF1D0)]) +# Initialize the wire codec pipeline +wire.setup() + # Load default homescreen from apps.homescreen.layout_homescreen import layout_homescreen # Run main even loop and specify, which screen is default -trezor.main.run(main_layout=layout_homescreen) +trezor.main.run(default_workflow=layout_homescreen) diff --git a/src/trezor/dispatcher.py b/src/trezor/dispatcher.py deleted file mode 100644 index 8ab2e2c68..000000000 --- a/src/trezor/dispatcher.py +++ /dev/null @@ -1,21 +0,0 @@ -from . import wire -from . import layout - - -message_handlers = {} - - -def register(mtype, handler): - if mtype in message_handlers: - raise Exception('Message wire type %s is already registered', mtype) - message_handlers[mtype] = handler - - -def unregister(mtype): - del message_handlers[mtype] - - -def dispatch(): - _, mtype, mbuf = yield from wire.read_wire_msg() - handler = message_handlers[mtype] - layout.change(handler(mtype, mbuf)) diff --git a/src/trezor/layout.py b/src/trezor/layout.py deleted file mode 100644 index e54922504..000000000 --- a/src/trezor/layout.py +++ /dev/null @@ -1,33 +0,0 @@ -import utime - -from . import log -from . import utils - - -class ChangeLayoutException(Exception): - - def __init__(self, layout): - self.layout = layout - - -def change(layout): - raise ChangeLayoutException(layout) - - -def set_main(main_layout): - layout = main_layout() - - while True: - try: - layout = yield from layout - except ChangeLayoutException as e: - layout = e.layout - except Exception as e: - log.exception(__name__, e) - utime.sleep(1) # Don't produce wall of exceptions - - if not isinstance(layout, utils.type_gen): - log.info(__name__, 'Switching to main layout %s', main_layout) - layout = main_layout() - else: - log.info(__name__, 'Switching to proposed layout %s', layout) diff --git a/src/trezor/main.py b/src/trezor/main.py index e9ac489b3..c2a30743c 100644 --- a/src/trezor/main.py +++ b/src/trezor/main.py @@ -4,7 +4,7 @@ sys.path.append('lib') import gc from trezor import loop -from trezor import layout +from trezor import workflow from trezor import log log.level = log.INFO @@ -20,7 +20,7 @@ def perf_info_debug(): mem_alloc = gc.mem_alloc() gc.collect() log.info(__name__, "mem_alloc: %s/%s, delay_avg: %d, delay_last: %d, queue: %s", - mem_alloc, gc.mem_alloc(), delay_avg, delay_last, ', '.join(queue)) + mem_alloc, gc.mem_alloc(), delay_avg, delay_last, ', '.join(queue)) yield loop.Sleep(1000000) @@ -32,10 +32,10 @@ def perf_info(): yield loop.Sleep(1000000) -def run(main_layout): +def run(default_workflow): if __debug__: loop.schedule_task(perf_info_debug()) else: loop.schedule_task(perf_info()) - loop.schedule_task(layout.set_main(main_layout)) + workflow.start_default(default_workflow) loop.run_forever() diff --git a/src/trezor/messages/__init__.py b/src/trezor/messages/__init__.py index e69de29bb..10d61a35b 100644 --- a/src/trezor/messages/__init__.py +++ b/src/trezor/messages/__init__.py @@ -0,0 +1,13 @@ +from . import wire_types + + +def get_protobuf_type_name(wire_type): + for name in dir(wire_types): + if getattr(wire_types, name) == wire_type: + return name + + +def get_protobuf_type(wire_type): + name = get_protobuf_type_name(wire_type) + module = __import__('.%s' % name, globals(), locals(), (name,), 1) + return getattr(module, name) diff --git a/src/trezor/wire.py b/src/trezor/wire.py deleted file mode 100644 index 98e87a323..000000000 --- a/src/trezor/wire.py +++ /dev/null @@ -1,131 +0,0 @@ -import ustruct -import ubinascii -from . import msg -from . import loop -from . import log - -IFACE = const(0) - -# TREZOR wire protocol v2: -# -# HID report = 64 bytes, padded with 0x0 -# First report = !SSSSTTTTLLLLD... -# Next reports = #SSSSD...CCCC -# -# S = session id -# T = message type -# L = data length -# D = data -# C = data checksum - crc32 - -_REPORT_LEN = const(64) -_MAX_DATA_LEN = const(65536) -_HEADER_MAGIC = const(33) # ord('!') -_DATA_MAGIC = const(35) # ord('#') - - -def _read_report(): - rep, = yield loop.Select(IFACE) - assert len(rep) == _REPORT_LEN, 'HID read failed' - return memoryview(rep) - - -def _write_report(rep): - size = msg.send(IFACE, rep) - assert size == _REPORT_LEN, 'HID write failed' - yield # just to be a generator - - -def read_wire_msg(): - - rep = yield from _read_report() - magic, sid, mtype, mlen = ustruct.unpack('>BLLL', rep) - assert magic == _HEADER_MAGIC, 'Incorrect report magic' - assert mlen < _MAX_DATA_LEN, 'Message too large to read' - - mlen += 4 # Account for the checksum - data = rep[13:][:mlen] # Skip magic and header, trim to data len - remaining = mlen - len(data) - # Avoid the copy if we don't append - buffered = bytearray(data) if remaining > 0 else data - - while remaining > 0: - rep = yield from _read_report() - magic, rsid = ustruct.unpack('>BL', rep) - assert magic == _DATA_MAGIC, 'Incorrect report magic' - assert rsid == sid, 'Session ID mismatch' - - data = rep[5:][:remaining] # Skip magic and session ID, trim - buffered.extend(data) - remaining -= len(data) - - # Split to data and checksum - mbuf = buffered[:-4] - csum = ustruct.unpack_from('>L', buffered, -4) - - # Compare the checksums - if hasattr(ubinascii, 'crc32'): - assert csum == ubinascii.crc32(mbuf), 'Message checksum mismatch' - - return sid, mtype, mbuf - - -def write_wire_msg(sid, mtype, mbuf): - - rep = bytearray(_REPORT_LEN) - ustruct.pack_into('>BLLL', rep, 0, _HEADER_MAGIC, sid, mtype, len(mbuf)) - - rep = memoryview(rep) - mbuf = memoryview(mbuf) - data = rep[13:] # Skip magic and header - - if hasattr(ubinascii, 'crc32'): - csum = ubinascii.crc32(mbuf) - else: - csum = 0 - footer = ustruct.pack('>L', csum) - - while True: - n = min(len(data), len(mbuf)) - data[:n] = mbuf[:n] # Copy as much data as possible from mbuf to data - mbuf = mbuf[n:] # Skip written bytes - data = data[n:] # Skip written bytes - - # Continue with the footer if mbuf is empty and we have space - if not mbuf and footer and data: - mbuf = footer - footer = None - continue - - yield from _write_report(rep) - if not mbuf: - break - - # Reset to skip the magic and session ID - rep[0] = _DATA_MAGIC - data = rep[5:] - - -def read(*types): - if __debug__: - log.debug(__name__, 'Reading one of %s', types) - _, mtype, mbuf = yield from read_wire_msg() - for t in types: - if t.wire_type == mtype: - return t.loads(mbuf) - else: - raise Exception('Unexpected message') - - -def write(m): - if __debug__: - log.debug(__name__, 'Writing %s', m) - mbuf = m.dumps() - mtype = m.message_type.wire_type - yield from write_wire_msg(0, mtype, mbuf) - - -def call(req, *types): - yield from write(req) - res = yield from read(*types) - return res diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py new file mode 100644 index 000000000..ed9d57544 --- /dev/null +++ b/src/trezor/wire/__init__.py @@ -0,0 +1,155 @@ +from protobuf.protobuf import build_protobuf_message + +from trezor.loop import schedule_task, Future +from trezor.crypto import random +from trezor.messages import get_protobuf_type +from trezor.workflow import start_workflow +from trezor import log + +from .wire_io import read_report_stream, write_report_stream +from .wire_dispatcher import dispatch_reports_by_session +from .wire_codec import \ + decode_wire_stream, encode_wire_message, \ + encode_session_open_message, encode_session_close_message + +_session_handlers = {} # session id -> generator +_workflow_genfuncs = {} # wire type -> (generator function, args) +_opened_sessions = set() # session ids + + +def generate_session_id(): + while True: + session_id = random.uniform(0x0fffffff) + 1 + if session_id not in _opened_sessions: + return session_id + + +def open_session(): + session_id = generate_session_id() + _opened_sessions.add(session_id) + log.info(__name__, 'opened session %d: %s', session_id, _opened_sessions) + return session_id + + +def close_session(session_id): + _opened_sessions.discard(session_id) + _session_handlers.pop(session_id, None) + log.info(__name__, 'closed session %d: %s', session_id, _opened_sessions) + + +def register_type(wire_type, genfunc, *args): + if wire_type in _workflow_genfuncs: + raise KeyError('message of type %d already registered' % wire_type) + log.info(__name__, 'registering %s for type %d', + (genfunc, args), wire_type) + _workflow_genfuncs[wire_type] = (genfunc, args) + + +def register_session(session_id, handler): + if session_id not in _opened_sessions: + raise KeyError('session %d is unknown' % session_id) + if session_id in _session_handlers: + raise KeyError('session %d is already registered' % session_id) + log.info(__name__, 'registering %s for session %d', handler, session_id) + _session_handlers[session_id] = handler + + +def setup(): + report_writer = write_report_stream() + report_writer.send(None) + + open_session_handler = _handle_open_session(report_writer) + open_session_handler.send(None) + + close_session_handler = _handle_close_session(report_writer) + close_session_handler.send(None) + + fallback_session_handler = _handle_unknown_session() + fallback_session_handler.send(None) + + session_dispatcher = dispatch_reports_by_session( + _session_handlers, + open_session_handler, + close_session_handler, + fallback_session_handler) + session_dispatcher.send(None) + + schedule_task(read_report_stream(session_dispatcher)) + + +async def read_message(session_id, *exp_types): + future = Future() + wire_decoder = decode_wire_stream( + _dispatch_and_build_protobuf, session_id, exp_types, future) + wire_decoder.send(None) + register_session(session_id, wire_decoder) + return await future + + +async def write_message(session_id, pbuf_message): + msg_data = await pbuf_message.dumps() + msg_type = pbuf_message.message_type.wire_type + writer = write_report_stream() + writer.send(None) + encode_wire_message(msg_type, msg_data, session_id, writer) + + +def protobuf_handler(msg_type, data_len, session_id, callback, *args): + def finalizer(message): + start_workflow(callback(message, session_id, *args)) + pbuf_type = get_protobuf_type(msg_type) + builder = build_protobuf_message(pbuf_type, finalizer) + builder.send(None) + return pbuf_type.load(builder) + + +def _handle_open_session(write_target): + while True: + yield + session_id = open_session() + wire_decoder = decode_wire_stream(_handle_registered_type, session_id) + wire_decoder.send(None) + register_session(session_id, wire_decoder) + encode_session_open_message(session_id, write_target) + + +def _handle_close_session(write_target): + while True: + session_id = yield + close_session(session_id) + encode_session_close_message(session_id, write_target) + + +def _handle_unknown_session(): + while True: + yield # TODO + + +class UnexpectedMessageError(Exception): + pass + + +def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, future): + if msg_type in exp_types: + pbuf_type = get_protobuf_type(msg_type) + builder = build_protobuf_message(pbuf_type, future.resolve) + builder.send(None) + return pbuf_type.load(builder) + else: + future.resolve(UnexpectedMessageError(msg_type)) + return _handle_registered_type(msg_type, data_len, session_id) + + +def _handle_registered_type(msg_type, data_len, session_id): + genfunc, args = _workflow_genfuncs.get( + msg_type, (_handle_unexpected_type, ())) + return genfunc(msg_type, data_len, session_id, *args) + + +def _handle_unexpected_type(msg_type, data_len, session_id): + log.info(__name__, 'skipping message %d of len %d' % (msg_type, data_len)) + try: + while True: + yield + except EOFError: + pass diff --git a/src/trezor/wire/wire_codec.py b/src/trezor/wire/wire_codec.py new file mode 100644 index 000000000..d9795dfb0 --- /dev/null +++ b/src/trezor/wire/wire_codec.py @@ -0,0 +1,173 @@ +import ustruct +import ubinascii + +# trezor wire protocol #2: +# +# # hid report (64B) +# - report marker (1B) +# - session id (4B, BE) +# - payload (59B) +# +# # message +# - streamed as payloads of hid reports +# - message type (4B, BE) +# - data length (4B, BE) +# - data (var-length) +# - data crc32 checksum (4B, BE) +# +# # sessions +# - reports are interleaved, need to be dispatched by session id + +REP_MARKER_HEADER = const(72) # ord('H') +REP_MARKER_DATA = const(68) # ord('D') +REP_MARKER_OPEN = const(79) # ord('O') +REP_MARKER_CLOSE = const(67) # ord('C') + +_REP_HEADER = '>BL' # marker, session id +_MSG_HEADER = '>LL' # msg type, data length +_MSG_FOOTER = '>L' # data checksum + +_REP_LEN = const(64) +_REP_HEADER_LEN = ustruct.calcsize(_REP_HEADER) +_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER) +_MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER) + + +def parse_report(data): + marker, session_id = ustruct.unpack(_REP_HEADER, data) + # TODO: handle v1 protocol + return marker, session_id, data[_REP_HEADER_LEN:] + + +def parse_message(data): + msg_type, data_len = ustruct.unpack(_MSG_HEADER, data) + return msg_type, data_len, data[_MSG_HEADER_LEN:] + + +def parse_message_footer(data): + data_checksum, = ustruct.unpack(_MSG_FOOTER, data) + return data_checksum, + + +def serialize_report_header(data, marker, session_id): + ustruct.pack_into(_REP_HEADER, data, 0, marker, session_id) + + +def serialize_message_header(data, msg_type, msg_len): + ustruct.pack_into(_MSG_HEADER, data, _REP_HEADER_LEN, msg_type, msg_len) + + +def serialize_message_footer(data, checksum): + ustruct.pack_into(_MSG_FOOTER, data, 0, checksum) + + +def serialize_opened_session(data, session_id): + serialize_report_header(data, REP_MARKER_OPEN, session_id) + + +class MessageChecksumError(Exception): + pass + + +def decode_wire_stream(genfunc, session_id, *args): + '''Decode a wire message from the report data and stream it to target. + +Receives report payloads. +Sends (msg_type, data_len) to target, followed by data chunks. +Throws EOFError after last data chunk, in case of valid checksum. +Throws MessageChecksumError to target if data doesn't match the checksum. +''' + message = yield # read first report + msg_type, data_len, data_tail = parse_message(message) + + target = genfunc(msg_type, data_len, session_id, *args) + target.send(None) + + checksum = 0 # crc32 + nreports = 1 + + compute_checksum = hasattr(ubinascii, 'crc32') + + while data_len > 0: + if nreports > 1: + data_tail = yield # read next report + nreports += 1 + + data_chunk = data_tail[:data_len] # slice off the garbage at the end + data_tail = data_tail[len(data_chunk):] # slice off what we have read + data_len -= len(data_chunk) + target.send(data_chunk) + + if compute_checksum: + checksum = ubinascii.crc32(checksum, data_chunk) + + msg_footer = data_tail[:_MSG_FOOTER_LEN] + if len(msg_footer) < _MSG_FOOTER_LEN: + data_tail = yield # read report with the rest of checksum + msg_footer += data_tail[:_MSG_FOOTER_LEN - len(msg_footer)] + + if compute_checksum: + data_checksum, = parse_message_footer(msg_footer) + else: + data_checksum = checksum + if data_checksum != checksum: + target.throw(MessageChecksumError, 'Message checksum mismatch') + else: + target.throw(EOFError) + + +def encode_wire_message(msg_type, msg_data, session_id, target): + report = bytearray(_REP_LEN) + serialize_report_header(report, REP_MARKER_HEADER, session_id) + serialize_message_header(report, msg_type, len(msg_data)) + + source_data = memoryview(msg_data) + target_data = memoryview(report)[_REP_HEADER_LEN + _MSG_HEADER_LEN:] + + compute_checksum = hasattr(ubinascii, 'crc32') + + if compute_checksum: + checksum = ubinascii.crc32(msg_data) + else: + checksum = 0 + + msg_footer = bytearray(_MSG_FOOTER_LEN) + serialize_message_footer(msg_footer, checksum) + + first = True + + while True: + # move as much as possible from source to target + n = min(len(target_data), len(source_data)) + target_data[:n] = source_data[:n] + source_data = source_data[n:] + target_data = target_data[n:] + + # continue with the footer if source is empty and we have space + if not source_data and target_data and msg_footer: + source_data = msg_footer + msg_footer = None + continue + + target.send(report) + + if not source_data and not msg_footer: + break + + if first: + # reset to skip the magic and session ID + serialize_report_header(report, REP_MARKER_DATA, session_id) + target_data = report[_REP_HEADER_LEN:] + first = False + + +def encode_session_open_message(session_id, target): + report = bytearray(_REP_LEN) + serialize_report_header(report, REP_MARKER_OPEN, session_id) + target.send(report) + + +def encode_session_close_message(session_id, target): + report = bytearray(_REP_LEN) + serialize_report_header(report, REP_MARKER_CLOSE, session_id) + target.send(report) diff --git a/src/trezor/wire/wire_dispatcher.py b/src/trezor/wire/wire_dispatcher.py new file mode 100644 index 000000000..729c555dc --- /dev/null +++ b/src/trezor/wire/wire_dispatcher.py @@ -0,0 +1,40 @@ +from trezor import log +from .wire_codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE + + +def dispatch_reports_by_session(handlers, + open_handler, + close_handler, + fallback_handler): + ''' + Consumes reports adhering to the wire codec and dispatches the report + payloads by between the passed handlers. + ''' + + while True: + marker, session_id, report_data = parse_report((yield)) + + if marker == REP_MARKER_OPEN: + log.debug(__name__, 'request for new session') + open_handler.send(session_id) + continue + + elif marker == REP_MARKER_CLOSE: + log.debug(__name__, 'request for closing session %d', session_id) + close_handler.send(session_id) + continue + + elif session_id in handlers: + log.debug(__name__, 'report on session %d', session_id) + handler = handlers[session_id] + + else: + log.debug(__name__, 'report on unknown session %d', session_id) + handler = fallback_handler + + try: + handler.send(report_data) + except StopIteration: + handlers.pop(session_id) + except Exception as e: + log.exception(__name__, e) diff --git a/src/trezor/wire/wire_io.py b/src/trezor/wire/wire_io.py new file mode 100644 index 000000000..95017a7a8 --- /dev/null +++ b/src/trezor/wire/wire_io.py @@ -0,0 +1,16 @@ +from trezor import msg +from trezor import loop + +_DEFAULT_IFACE = const(0) + + +def read_report_stream(target, iface=_DEFAULT_IFACE): + while True: + report, = yield loop.Select(iface) + target.send(report) + + +def write_report_stream(iface=_DEFAULT_IFACE): + while True: + report = yield + msg.send(iface, report) diff --git a/src/trezor/wire_streaming.py b/src/trezor/wire_streaming.py deleted file mode 100644 index fd65e12ca..000000000 --- a/src/trezor/wire_streaming.py +++ /dev/null @@ -1,166 +0,0 @@ -import ustruct -import ubinascii - -from . import msg -from . import loop -from .crypto import random - - -MESSAGE_IFACE = const(0) -EMPTY_SESSION = const(0) - -sessions = {} - - -def generate_session_id(): - return random.uniform(0xffffffff) + 1 - - -async def dispatch_reports(): - while True: - report = await _read_report() - session_id, report_data = _parse_report(report) - sessions[session_id].send(report_data) - - -async def read_session_message(session_id, types): - future = loop.Future() - pbuf_decoder = _decode_protobuf_message(types, future) - wire_decoder = _decode_wire_message(pbuf_decoder) - assert session_id not in sessions - sessions[session_id] = wire_decoder - try: - result = await future - finally: - del sessions[session_id] - return result - - -def lookup_protobuf_type(msg_type, pbuf_types): - for pt in pbuf_types: - if pt.wire_type == msg_type: - return pt - return None - - -def _decode_protobuf_message(types, future): - msg_type, _ = yield - pbuf_type = lookup_protobuf_type(msg_type, types) - target = build_protobuf_message(pbuf_type, future) - yield from pbuf_type.load(AsyncBytearrayReader(), target) - - -class AsyncBytearrayReader: - - def __init__(self, buf=None, n=None): - self.buf = buf if buf is not None else bytearray() - self.n = n - - def read(self, n): - if self.n is not None: - self.n -= n - if self.n <= 0: - raise EOFError() - buf = self.buf - while len(buf) < n: - buf.extend((yield)) # buffer next data chunk - result, buf[:] = buf[:n], buf[n:] - return result - - def limit(self, n): - return AsyncBytearrayReader(self.buf, n) - - -async def _read_report(): - report, = await loop.Select(MESSAGE_IFACE) - return memoryview(report) # make slicing cheap - - -async def _write_report(report): - return msg.send(MESSAGE_IFACE, report) - - -# TREZOR wire protocol v2: -# -# HID report (64B): -# - report magic (1B) -# - session (4B, BE) -# - payload (59B) -# -# message: -# - streamed as payloads of HID reports: -# - message type (4B, BE) -# - data length (4B, BE) -# - data (var-length) -# - data checksum (4B, BE) - - -REP_HEADER = '>BL' # marker, session id -MSG_HEADER = '>LL' # msg type, data length -MSG_FOOTER = '>L' # data checksum - -REP_HEADER_LEN = ustruct.calcsize(REP_HEADER) -MSG_HEADER_LEN = ustruct.calcsize(MSG_HEADER) -MSG_FOOTER_LEN = ustruct.calcsize(MSG_FOOTER) - - -class MessageChecksumError(Exception): - pass - - -def _parse_report(data): - marker, session_id = ustruct.parse(REP_HEADER, data) - return session_id, data[REP_HEADER_LEN:] - - -def _parse_message(data): - msg_type, data_len = ustruct.parse(MSG_HEADER, data) - return msg_type, data_len, data[MSG_HEADER_LEN:] - - -def _parse_footer(data): - data_checksum, = ustruct.parse(MSG_FOOTER, data) - return data_checksum, - - -def _decode_wire_message(target): - '''Decode a wire message from the report data and stream it to target. - -Receives report payloads. -Sends (msg_type, data_len) to target, followed by data chunks. -Throws EOFError after last data chunk, in case of valid checksum. -Throws MessageChecksumError to target if data doesn't match the checksum. -''' - message = (yield) # read first report - msg_type, data_len, data_tail = _parse_message(message) - target.send((msg_type, data_len)) - - checksum = 0 # crc32 - nreports = 1 - - while data_len > 0: - if nreports > 1: - data_tail = (yield) # read next report - nreports += 1 - - data_chunk = data_tail[:data_len] # slice off the garbage at the end - data_tail = data_tail[len(data_chunk):] # slice off what we have read - data_len -= len(data_chunk) - target.send(data_chunk) - - checksum = ubinascii.crc32(checksum, data_chunk) - - data_footer = data_tail[:MSG_FOOTER_LEN] - if len(data_footer) < MSG_FOOTER_LEN: - data_tail = (yield) # read report with the rest of checksum - data_footer += data_tail[:MSG_FOOTER_LEN - len(data_footer)] - - data_checksum, = _parse_footer(data_footer) - if data_checksum != checksum: - target.throw(MessageChecksumError, 'Message checksum mismatch') - else: - target.throw(EOFError) - - -def _encode_message(target): - pass diff --git a/src/trezor/workflow.py b/src/trezor/workflow.py new file mode 100644 index 000000000..4766dd975 --- /dev/null +++ b/src/trezor/workflow.py @@ -0,0 +1,37 @@ +from trezor import log, loop + +_started_workflows = [] +_default_workflow = None +_default_workflow_genfunc = None + + +def start_default(genfunc): + global _default_workflow + global _default_workflow_genfunc + _default_workflow_genfunc = genfunc + _default_workflow = _default_workflow_genfunc() + log.info(__name__, 'starting default workflow %s', _default_workflow) + loop.schedule_task(_default_workflow) + + +def start_workflow(workflow): + global _default_workflow + if _default_workflow is not None: + log.info(__name__, 'closing default workflow %s', _default_workflow) + _default_workflow.close() + _default_workflow = None + + log.info(__name__, 'starting workflow %s', workflow) + _started_workflows.append(workflow) + loop.schedule_task(watch_workflow(workflow)) + + +async def watch_workflow(workflow): + global _default_workflow + try: + return await workflow + finally: + _started_workflows.remove(workflow) + + if not _started_workflows and _default_workflow_genfunc is not None: + start_default(_default_workflow_genfunc)