mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-03 11:20:59 +00:00
session/workflow overhaul
- lazy loading and importing of protobuf messages - workflow dispatching through codec pipeline on the first message report HACK: workflow killing TODO: cap on concurrent sessions TODO: ui locking/queuing TODO: session storage TODO: tests
This commit is contained in:
parent
99485b3385
commit
099767d592
@ -1,17 +1,29 @@
|
|||||||
from trezor.dispatcher import register
|
from trezor.wire import register_type, protobuf_handler, write_message
|
||||||
from trezor.utils import unimport_func
|
from trezor.utils import unimport
|
||||||
|
from trezor.messages.wire_types import Initialize
|
||||||
|
|
||||||
|
|
||||||
@unimport_func
|
@unimport
|
||||||
def dispatch_Initialize(mtype, mbuf):
|
async def dispatch_Initialize(_, session_id):
|
||||||
from trezor.messages.Initialize import Initialize
|
from trezor.messages.Features import Features
|
||||||
|
features = Features(
|
||||||
message = Initialize.loads(mbuf)
|
revision='deadbeef',
|
||||||
|
bootloader_hash='deadbeef',
|
||||||
from .layout_homescreen import layout_homescreen
|
device_id='DEADBEEF',
|
||||||
return layout_homescreen(message)
|
coins=[],
|
||||||
|
imported=False,
|
||||||
|
initialized=False,
|
||||||
|
label='My TREZOR',
|
||||||
|
major_version=2,
|
||||||
|
minor_version=0,
|
||||||
|
patch_version=0,
|
||||||
|
pin_cached=False,
|
||||||
|
pin_protection=True,
|
||||||
|
passphrase_cached=False,
|
||||||
|
passphrase_protection=False,
|
||||||
|
vendor='bitcointrezor.com')
|
||||||
|
await write_message(session_id, features)
|
||||||
|
|
||||||
|
|
||||||
def boot():
|
def boot():
|
||||||
Initialize = 0
|
register_type(Initialize, protobuf_handler, dispatch_Initialize)
|
||||||
register(Initialize, dispatch_Initialize)
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from trezor import ui, dispatcher, loop, res, wire
|
from trezor import ui, loop, res
|
||||||
from trezor.ui.swipe import Swipe
|
from trezor.ui.swipe import Swipe
|
||||||
from trezor.utils import unimport_gen
|
|
||||||
|
|
||||||
|
|
||||||
async def swipe_to_rotate():
|
async def swipe_to_rotate():
|
||||||
@ -10,37 +9,14 @@ async def swipe_to_rotate():
|
|||||||
|
|
||||||
|
|
||||||
async def animate_logo():
|
async def animate_logo():
|
||||||
# def func(foreground):
|
icon = res.load('apps/homescreen/res/trezor.toig')
|
||||||
# ui.display.icon(0, 0, res.load(
|
|
||||||
# 'apps/homescreen/res/trezor.toig'), foreground, ui.BLACK)
|
|
||||||
# await ui.animate_pulse(func, ui.WHITE, ui.GREY, speed=400000)
|
|
||||||
|
|
||||||
async for fg in ui.pulse_animation(ui.WHITE, ui.GREY, speed=400000):
|
async for fg in ui.pulse_animation(ui.WHITE, ui.GREY, speed=400000):
|
||||||
icon = res.load('apps/homescreen/res/trezor.toig')
|
|
||||||
ui.display.icon(0, 0, icon, fg, ui.BLACK)
|
ui.display.icon(0, 0, icon, fg, ui.BLACK)
|
||||||
|
|
||||||
|
|
||||||
@unimport_gen
|
async def layout_homescreen():
|
||||||
async def layout_homescreen(initialize_msg=None):
|
wait = loop.Wait([swipe_to_rotate(), animate_logo()])
|
||||||
if initialize_msg is not None:
|
try:
|
||||||
from trezor.messages.Features import Features
|
await wait
|
||||||
features = Features()
|
finally:
|
||||||
features.revision = 'deadbeef'
|
wait.exit()
|
||||||
features.bootloader_hash = 'deadbeef'
|
|
||||||
features.device_id = 'DEADBEEF'
|
|
||||||
features.coins = []
|
|
||||||
features.imported = False
|
|
||||||
features.initialized = False
|
|
||||||
features.label = 'My TREZOR'
|
|
||||||
features.major_version = 2
|
|
||||||
features.minor_version = 0
|
|
||||||
features.patch_version = 0
|
|
||||||
features.pin_cached = False
|
|
||||||
features.pin_protection = True
|
|
||||||
features.passphrase_cached = False
|
|
||||||
features.passphrase_protection = False
|
|
||||||
features.vendor = 'bitcointrezor.com'
|
|
||||||
await wire.write(features)
|
|
||||||
await loop.Wait([dispatcher.dispatch(),
|
|
||||||
swipe_to_rotate(),
|
|
||||||
animate_logo()])
|
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import trezor.main
|
import trezor.main
|
||||||
from trezor import msg
|
from trezor import msg
|
||||||
|
from trezor import ui
|
||||||
|
from trezor import wire
|
||||||
|
|
||||||
# Load all applications
|
# Load all applications
|
||||||
from apps import playground
|
from apps import playground
|
||||||
@ -14,13 +16,16 @@ management.boot()
|
|||||||
wallet.boot()
|
wallet.boot()
|
||||||
|
|
||||||
# Change backlight to white for better visibility
|
# Change backlight to white for better visibility
|
||||||
trezor.ui.display.backlight(255)
|
ui.display.backlight(255)
|
||||||
|
|
||||||
# Just a demo to show how to register USB ifaces
|
# Just a demo to show how to register USB ifaces
|
||||||
msg.setup([(1, 0xF53C), (2, 0xF1D0)])
|
msg.setup([(1, 0xF53C), (2, 0xF1D0)])
|
||||||
|
|
||||||
|
# Initialize the wire codec pipeline
|
||||||
|
wire.setup()
|
||||||
|
|
||||||
# Load default homescreen
|
# Load default homescreen
|
||||||
from apps.homescreen.layout_homescreen import layout_homescreen
|
from apps.homescreen.layout_homescreen import layout_homescreen
|
||||||
|
|
||||||
# Run main even loop and specify, which screen is default
|
# Run main even loop and specify, which screen is default
|
||||||
trezor.main.run(main_layout=layout_homescreen)
|
trezor.main.run(default_workflow=layout_homescreen)
|
||||||
|
@ -1,21 +0,0 @@
|
|||||||
from . import wire
|
|
||||||
from . import layout
|
|
||||||
|
|
||||||
|
|
||||||
message_handlers = {}
|
|
||||||
|
|
||||||
|
|
||||||
def register(mtype, handler):
|
|
||||||
if mtype in message_handlers:
|
|
||||||
raise Exception('Message wire type %s is already registered', mtype)
|
|
||||||
message_handlers[mtype] = handler
|
|
||||||
|
|
||||||
|
|
||||||
def unregister(mtype):
|
|
||||||
del message_handlers[mtype]
|
|
||||||
|
|
||||||
|
|
||||||
def dispatch():
|
|
||||||
_, mtype, mbuf = yield from wire.read_wire_msg()
|
|
||||||
handler = message_handlers[mtype]
|
|
||||||
layout.change(handler(mtype, mbuf))
|
|
@ -1,33 +0,0 @@
|
|||||||
import utime
|
|
||||||
|
|
||||||
from . import log
|
|
||||||
from . import utils
|
|
||||||
|
|
||||||
|
|
||||||
class ChangeLayoutException(Exception):
|
|
||||||
|
|
||||||
def __init__(self, layout):
|
|
||||||
self.layout = layout
|
|
||||||
|
|
||||||
|
|
||||||
def change(layout):
|
|
||||||
raise ChangeLayoutException(layout)
|
|
||||||
|
|
||||||
|
|
||||||
def set_main(main_layout):
|
|
||||||
layout = main_layout()
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
layout = yield from layout
|
|
||||||
except ChangeLayoutException as e:
|
|
||||||
layout = e.layout
|
|
||||||
except Exception as e:
|
|
||||||
log.exception(__name__, e)
|
|
||||||
utime.sleep(1) # Don't produce wall of exceptions
|
|
||||||
|
|
||||||
if not isinstance(layout, utils.type_gen):
|
|
||||||
log.info(__name__, 'Switching to main layout %s', main_layout)
|
|
||||||
layout = main_layout()
|
|
||||||
else:
|
|
||||||
log.info(__name__, 'Switching to proposed layout %s', layout)
|
|
@ -4,7 +4,7 @@ sys.path.append('lib')
|
|||||||
import gc
|
import gc
|
||||||
|
|
||||||
from trezor import loop
|
from trezor import loop
|
||||||
from trezor import layout
|
from trezor import workflow
|
||||||
from trezor import log
|
from trezor import log
|
||||||
|
|
||||||
log.level = log.INFO
|
log.level = log.INFO
|
||||||
@ -20,7 +20,7 @@ def perf_info_debug():
|
|||||||
mem_alloc = gc.mem_alloc()
|
mem_alloc = gc.mem_alloc()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
log.info(__name__, "mem_alloc: %s/%s, delay_avg: %d, delay_last: %d, queue: %s",
|
log.info(__name__, "mem_alloc: %s/%s, delay_avg: %d, delay_last: %d, queue: %s",
|
||||||
mem_alloc, gc.mem_alloc(), delay_avg, delay_last, ', '.join(queue))
|
mem_alloc, gc.mem_alloc(), delay_avg, delay_last, ', '.join(queue))
|
||||||
|
|
||||||
yield loop.Sleep(1000000)
|
yield loop.Sleep(1000000)
|
||||||
|
|
||||||
@ -32,10 +32,10 @@ def perf_info():
|
|||||||
yield loop.Sleep(1000000)
|
yield loop.Sleep(1000000)
|
||||||
|
|
||||||
|
|
||||||
def run(main_layout):
|
def run(default_workflow):
|
||||||
if __debug__:
|
if __debug__:
|
||||||
loop.schedule_task(perf_info_debug())
|
loop.schedule_task(perf_info_debug())
|
||||||
else:
|
else:
|
||||||
loop.schedule_task(perf_info())
|
loop.schedule_task(perf_info())
|
||||||
loop.schedule_task(layout.set_main(main_layout))
|
workflow.start_default(default_workflow)
|
||||||
loop.run_forever()
|
loop.run_forever()
|
||||||
|
@ -0,0 +1,13 @@
|
|||||||
|
from . import wire_types
|
||||||
|
|
||||||
|
|
||||||
|
def get_protobuf_type_name(wire_type):
|
||||||
|
for name in dir(wire_types):
|
||||||
|
if getattr(wire_types, name) == wire_type:
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def get_protobuf_type(wire_type):
|
||||||
|
name = get_protobuf_type_name(wire_type)
|
||||||
|
module = __import__('.%s' % name, globals(), locals(), (name,), 1)
|
||||||
|
return getattr(module, name)
|
@ -1,131 +0,0 @@
|
|||||||
import ustruct
|
|
||||||
import ubinascii
|
|
||||||
from . import msg
|
|
||||||
from . import loop
|
|
||||||
from . import log
|
|
||||||
|
|
||||||
IFACE = const(0)
|
|
||||||
|
|
||||||
# TREZOR wire protocol v2:
|
|
||||||
#
|
|
||||||
# HID report = 64 bytes, padded with 0x0
|
|
||||||
# First report = !SSSSTTTTLLLLD...
|
|
||||||
# Next reports = #SSSSD...CCCC
|
|
||||||
#
|
|
||||||
# S = session id
|
|
||||||
# T = message type
|
|
||||||
# L = data length
|
|
||||||
# D = data
|
|
||||||
# C = data checksum - crc32
|
|
||||||
|
|
||||||
_REPORT_LEN = const(64)
|
|
||||||
_MAX_DATA_LEN = const(65536)
|
|
||||||
_HEADER_MAGIC = const(33) # ord('!')
|
|
||||||
_DATA_MAGIC = const(35) # ord('#')
|
|
||||||
|
|
||||||
|
|
||||||
def _read_report():
|
|
||||||
rep, = yield loop.Select(IFACE)
|
|
||||||
assert len(rep) == _REPORT_LEN, 'HID read failed'
|
|
||||||
return memoryview(rep)
|
|
||||||
|
|
||||||
|
|
||||||
def _write_report(rep):
|
|
||||||
size = msg.send(IFACE, rep)
|
|
||||||
assert size == _REPORT_LEN, 'HID write failed'
|
|
||||||
yield # just to be a generator
|
|
||||||
|
|
||||||
|
|
||||||
def read_wire_msg():
|
|
||||||
|
|
||||||
rep = yield from _read_report()
|
|
||||||
magic, sid, mtype, mlen = ustruct.unpack('>BLLL', rep)
|
|
||||||
assert magic == _HEADER_MAGIC, 'Incorrect report magic'
|
|
||||||
assert mlen < _MAX_DATA_LEN, 'Message too large to read'
|
|
||||||
|
|
||||||
mlen += 4 # Account for the checksum
|
|
||||||
data = rep[13:][:mlen] # Skip magic and header, trim to data len
|
|
||||||
remaining = mlen - len(data)
|
|
||||||
# Avoid the copy if we don't append
|
|
||||||
buffered = bytearray(data) if remaining > 0 else data
|
|
||||||
|
|
||||||
while remaining > 0:
|
|
||||||
rep = yield from _read_report()
|
|
||||||
magic, rsid = ustruct.unpack('>BL', rep)
|
|
||||||
assert magic == _DATA_MAGIC, 'Incorrect report magic'
|
|
||||||
assert rsid == sid, 'Session ID mismatch'
|
|
||||||
|
|
||||||
data = rep[5:][:remaining] # Skip magic and session ID, trim
|
|
||||||
buffered.extend(data)
|
|
||||||
remaining -= len(data)
|
|
||||||
|
|
||||||
# Split to data and checksum
|
|
||||||
mbuf = buffered[:-4]
|
|
||||||
csum = ustruct.unpack_from('>L', buffered, -4)
|
|
||||||
|
|
||||||
# Compare the checksums
|
|
||||||
if hasattr(ubinascii, 'crc32'):
|
|
||||||
assert csum == ubinascii.crc32(mbuf), 'Message checksum mismatch'
|
|
||||||
|
|
||||||
return sid, mtype, mbuf
|
|
||||||
|
|
||||||
|
|
||||||
def write_wire_msg(sid, mtype, mbuf):
|
|
||||||
|
|
||||||
rep = bytearray(_REPORT_LEN)
|
|
||||||
ustruct.pack_into('>BLLL', rep, 0, _HEADER_MAGIC, sid, mtype, len(mbuf))
|
|
||||||
|
|
||||||
rep = memoryview(rep)
|
|
||||||
mbuf = memoryview(mbuf)
|
|
||||||
data = rep[13:] # Skip magic and header
|
|
||||||
|
|
||||||
if hasattr(ubinascii, 'crc32'):
|
|
||||||
csum = ubinascii.crc32(mbuf)
|
|
||||||
else:
|
|
||||||
csum = 0
|
|
||||||
footer = ustruct.pack('>L', csum)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
n = min(len(data), len(mbuf))
|
|
||||||
data[:n] = mbuf[:n] # Copy as much data as possible from mbuf to data
|
|
||||||
mbuf = mbuf[n:] # Skip written bytes
|
|
||||||
data = data[n:] # Skip written bytes
|
|
||||||
|
|
||||||
# Continue with the footer if mbuf is empty and we have space
|
|
||||||
if not mbuf and footer and data:
|
|
||||||
mbuf = footer
|
|
||||||
footer = None
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield from _write_report(rep)
|
|
||||||
if not mbuf:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Reset to skip the magic and session ID
|
|
||||||
rep[0] = _DATA_MAGIC
|
|
||||||
data = rep[5:]
|
|
||||||
|
|
||||||
|
|
||||||
def read(*types):
|
|
||||||
if __debug__:
|
|
||||||
log.debug(__name__, 'Reading one of %s', types)
|
|
||||||
_, mtype, mbuf = yield from read_wire_msg()
|
|
||||||
for t in types:
|
|
||||||
if t.wire_type == mtype:
|
|
||||||
return t.loads(mbuf)
|
|
||||||
else:
|
|
||||||
raise Exception('Unexpected message')
|
|
||||||
|
|
||||||
|
|
||||||
def write(m):
|
|
||||||
if __debug__:
|
|
||||||
log.debug(__name__, 'Writing %s', m)
|
|
||||||
mbuf = m.dumps()
|
|
||||||
mtype = m.message_type.wire_type
|
|
||||||
yield from write_wire_msg(0, mtype, mbuf)
|
|
||||||
|
|
||||||
|
|
||||||
def call(req, *types):
|
|
||||||
yield from write(req)
|
|
||||||
res = yield from read(*types)
|
|
||||||
return res
|
|
155
src/trezor/wire/__init__.py
Normal file
155
src/trezor/wire/__init__.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
from protobuf.protobuf import build_protobuf_message
|
||||||
|
|
||||||
|
from trezor.loop import schedule_task, Future
|
||||||
|
from trezor.crypto import random
|
||||||
|
from trezor.messages import get_protobuf_type
|
||||||
|
from trezor.workflow import start_workflow
|
||||||
|
from trezor import log
|
||||||
|
|
||||||
|
from .wire_io import read_report_stream, write_report_stream
|
||||||
|
from .wire_dispatcher import dispatch_reports_by_session
|
||||||
|
from .wire_codec import \
|
||||||
|
decode_wire_stream, encode_wire_message, \
|
||||||
|
encode_session_open_message, encode_session_close_message
|
||||||
|
|
||||||
|
_session_handlers = {} # session id -> generator
|
||||||
|
_workflow_genfuncs = {} # wire type -> (generator function, args)
|
||||||
|
_opened_sessions = set() # session ids
|
||||||
|
|
||||||
|
|
||||||
|
def generate_session_id():
|
||||||
|
while True:
|
||||||
|
session_id = random.uniform(0x0fffffff) + 1
|
||||||
|
if session_id not in _opened_sessions:
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
def open_session():
|
||||||
|
session_id = generate_session_id()
|
||||||
|
_opened_sessions.add(session_id)
|
||||||
|
log.info(__name__, 'opened session %d: %s', session_id, _opened_sessions)
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
def close_session(session_id):
|
||||||
|
_opened_sessions.discard(session_id)
|
||||||
|
_session_handlers.pop(session_id, None)
|
||||||
|
log.info(__name__, 'closed session %d: %s', session_id, _opened_sessions)
|
||||||
|
|
||||||
|
|
||||||
|
def register_type(wire_type, genfunc, *args):
|
||||||
|
if wire_type in _workflow_genfuncs:
|
||||||
|
raise KeyError('message of type %d already registered' % wire_type)
|
||||||
|
log.info(__name__, 'registering %s for type %d',
|
||||||
|
(genfunc, args), wire_type)
|
||||||
|
_workflow_genfuncs[wire_type] = (genfunc, args)
|
||||||
|
|
||||||
|
|
||||||
|
def register_session(session_id, handler):
|
||||||
|
if session_id not in _opened_sessions:
|
||||||
|
raise KeyError('session %d is unknown' % session_id)
|
||||||
|
if session_id in _session_handlers:
|
||||||
|
raise KeyError('session %d is already registered' % session_id)
|
||||||
|
log.info(__name__, 'registering %s for session %d', handler, session_id)
|
||||||
|
_session_handlers[session_id] = handler
|
||||||
|
|
||||||
|
|
||||||
|
def setup():
|
||||||
|
report_writer = write_report_stream()
|
||||||
|
report_writer.send(None)
|
||||||
|
|
||||||
|
open_session_handler = _handle_open_session(report_writer)
|
||||||
|
open_session_handler.send(None)
|
||||||
|
|
||||||
|
close_session_handler = _handle_close_session(report_writer)
|
||||||
|
close_session_handler.send(None)
|
||||||
|
|
||||||
|
fallback_session_handler = _handle_unknown_session()
|
||||||
|
fallback_session_handler.send(None)
|
||||||
|
|
||||||
|
session_dispatcher = dispatch_reports_by_session(
|
||||||
|
_session_handlers,
|
||||||
|
open_session_handler,
|
||||||
|
close_session_handler,
|
||||||
|
fallback_session_handler)
|
||||||
|
session_dispatcher.send(None)
|
||||||
|
|
||||||
|
schedule_task(read_report_stream(session_dispatcher))
|
||||||
|
|
||||||
|
|
||||||
|
async def read_message(session_id, *exp_types):
|
||||||
|
future = Future()
|
||||||
|
wire_decoder = decode_wire_stream(
|
||||||
|
_dispatch_and_build_protobuf, session_id, exp_types, future)
|
||||||
|
wire_decoder.send(None)
|
||||||
|
register_session(session_id, wire_decoder)
|
||||||
|
return await future
|
||||||
|
|
||||||
|
|
||||||
|
async def write_message(session_id, pbuf_message):
|
||||||
|
msg_data = await pbuf_message.dumps()
|
||||||
|
msg_type = pbuf_message.message_type.wire_type
|
||||||
|
writer = write_report_stream()
|
||||||
|
writer.send(None)
|
||||||
|
encode_wire_message(msg_type, msg_data, session_id, writer)
|
||||||
|
|
||||||
|
|
||||||
|
def protobuf_handler(msg_type, data_len, session_id, callback, *args):
|
||||||
|
def finalizer(message):
|
||||||
|
start_workflow(callback(message, session_id, *args))
|
||||||
|
pbuf_type = get_protobuf_type(msg_type)
|
||||||
|
builder = build_protobuf_message(pbuf_type, finalizer)
|
||||||
|
builder.send(None)
|
||||||
|
return pbuf_type.load(builder)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_open_session(write_target):
|
||||||
|
while True:
|
||||||
|
yield
|
||||||
|
session_id = open_session()
|
||||||
|
wire_decoder = decode_wire_stream(_handle_registered_type, session_id)
|
||||||
|
wire_decoder.send(None)
|
||||||
|
register_session(session_id, wire_decoder)
|
||||||
|
encode_session_open_message(session_id, write_target)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_close_session(write_target):
|
||||||
|
while True:
|
||||||
|
session_id = yield
|
||||||
|
close_session(session_id)
|
||||||
|
encode_session_close_message(session_id, write_target)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_unknown_session():
|
||||||
|
while True:
|
||||||
|
yield # TODO
|
||||||
|
|
||||||
|
|
||||||
|
class UnexpectedMessageError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, future):
|
||||||
|
if msg_type in exp_types:
|
||||||
|
pbuf_type = get_protobuf_type(msg_type)
|
||||||
|
builder = build_protobuf_message(pbuf_type, future.resolve)
|
||||||
|
builder.send(None)
|
||||||
|
return pbuf_type.load(builder)
|
||||||
|
else:
|
||||||
|
future.resolve(UnexpectedMessageError(msg_type))
|
||||||
|
return _handle_registered_type(msg_type, data_len, session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_registered_type(msg_type, data_len, session_id):
|
||||||
|
genfunc, args = _workflow_genfuncs.get(
|
||||||
|
msg_type, (_handle_unexpected_type, ()))
|
||||||
|
return genfunc(msg_type, data_len, session_id, *args)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_unexpected_type(msg_type, data_len, session_id):
|
||||||
|
log.info(__name__, 'skipping message %d of len %d' % (msg_type, data_len))
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
yield
|
||||||
|
except EOFError:
|
||||||
|
pass
|
173
src/trezor/wire/wire_codec.py
Normal file
173
src/trezor/wire/wire_codec.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
import ustruct
|
||||||
|
import ubinascii
|
||||||
|
|
||||||
|
# trezor wire protocol #2:
|
||||||
|
#
|
||||||
|
# # hid report (64B)
|
||||||
|
# - report marker (1B)
|
||||||
|
# - session id (4B, BE)
|
||||||
|
# - payload (59B)
|
||||||
|
#
|
||||||
|
# # message
|
||||||
|
# - streamed as payloads of hid reports
|
||||||
|
# - message type (4B, BE)
|
||||||
|
# - data length (4B, BE)
|
||||||
|
# - data (var-length)
|
||||||
|
# - data crc32 checksum (4B, BE)
|
||||||
|
#
|
||||||
|
# # sessions
|
||||||
|
# - reports are interleaved, need to be dispatched by session id
|
||||||
|
|
||||||
|
REP_MARKER_HEADER = const(72) # ord('H')
|
||||||
|
REP_MARKER_DATA = const(68) # ord('D')
|
||||||
|
REP_MARKER_OPEN = const(79) # ord('O')
|
||||||
|
REP_MARKER_CLOSE = const(67) # ord('C')
|
||||||
|
|
||||||
|
_REP_HEADER = '>BL' # marker, session id
|
||||||
|
_MSG_HEADER = '>LL' # msg type, data length
|
||||||
|
_MSG_FOOTER = '>L' # data checksum
|
||||||
|
|
||||||
|
_REP_LEN = const(64)
|
||||||
|
_REP_HEADER_LEN = ustruct.calcsize(_REP_HEADER)
|
||||||
|
_MSG_HEADER_LEN = ustruct.calcsize(_MSG_HEADER)
|
||||||
|
_MSG_FOOTER_LEN = ustruct.calcsize(_MSG_FOOTER)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_report(data):
|
||||||
|
marker, session_id = ustruct.unpack(_REP_HEADER, data)
|
||||||
|
# TODO: handle v1 protocol
|
||||||
|
return marker, session_id, data[_REP_HEADER_LEN:]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_message(data):
|
||||||
|
msg_type, data_len = ustruct.unpack(_MSG_HEADER, data)
|
||||||
|
return msg_type, data_len, data[_MSG_HEADER_LEN:]
|
||||||
|
|
||||||
|
|
||||||
|
def parse_message_footer(data):
|
||||||
|
data_checksum, = ustruct.unpack(_MSG_FOOTER, data)
|
||||||
|
return data_checksum,
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_report_header(data, marker, session_id):
|
||||||
|
ustruct.pack_into(_REP_HEADER, data, 0, marker, session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_message_header(data, msg_type, msg_len):
|
||||||
|
ustruct.pack_into(_MSG_HEADER, data, _REP_HEADER_LEN, msg_type, msg_len)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_message_footer(data, checksum):
|
||||||
|
ustruct.pack_into(_MSG_FOOTER, data, 0, checksum)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_opened_session(data, session_id):
|
||||||
|
serialize_report_header(data, REP_MARKER_OPEN, session_id)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageChecksumError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def decode_wire_stream(genfunc, session_id, *args):
|
||||||
|
'''Decode a wire message from the report data and stream it to target.
|
||||||
|
|
||||||
|
Receives report payloads.
|
||||||
|
Sends (msg_type, data_len) to target, followed by data chunks.
|
||||||
|
Throws EOFError after last data chunk, in case of valid checksum.
|
||||||
|
Throws MessageChecksumError to target if data doesn't match the checksum.
|
||||||
|
'''
|
||||||
|
message = yield # read first report
|
||||||
|
msg_type, data_len, data_tail = parse_message(message)
|
||||||
|
|
||||||
|
target = genfunc(msg_type, data_len, session_id, *args)
|
||||||
|
target.send(None)
|
||||||
|
|
||||||
|
checksum = 0 # crc32
|
||||||
|
nreports = 1
|
||||||
|
|
||||||
|
compute_checksum = hasattr(ubinascii, 'crc32')
|
||||||
|
|
||||||
|
while data_len > 0:
|
||||||
|
if nreports > 1:
|
||||||
|
data_tail = yield # read next report
|
||||||
|
nreports += 1
|
||||||
|
|
||||||
|
data_chunk = data_tail[:data_len] # slice off the garbage at the end
|
||||||
|
data_tail = data_tail[len(data_chunk):] # slice off what we have read
|
||||||
|
data_len -= len(data_chunk)
|
||||||
|
target.send(data_chunk)
|
||||||
|
|
||||||
|
if compute_checksum:
|
||||||
|
checksum = ubinascii.crc32(checksum, data_chunk)
|
||||||
|
|
||||||
|
msg_footer = data_tail[:_MSG_FOOTER_LEN]
|
||||||
|
if len(msg_footer) < _MSG_FOOTER_LEN:
|
||||||
|
data_tail = yield # read report with the rest of checksum
|
||||||
|
msg_footer += data_tail[:_MSG_FOOTER_LEN - len(msg_footer)]
|
||||||
|
|
||||||
|
if compute_checksum:
|
||||||
|
data_checksum, = parse_message_footer(msg_footer)
|
||||||
|
else:
|
||||||
|
data_checksum = checksum
|
||||||
|
if data_checksum != checksum:
|
||||||
|
target.throw(MessageChecksumError, 'Message checksum mismatch')
|
||||||
|
else:
|
||||||
|
target.throw(EOFError)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_wire_message(msg_type, msg_data, session_id, target):
|
||||||
|
report = bytearray(_REP_LEN)
|
||||||
|
serialize_report_header(report, REP_MARKER_HEADER, session_id)
|
||||||
|
serialize_message_header(report, msg_type, len(msg_data))
|
||||||
|
|
||||||
|
source_data = memoryview(msg_data)
|
||||||
|
target_data = memoryview(report)[_REP_HEADER_LEN + _MSG_HEADER_LEN:]
|
||||||
|
|
||||||
|
compute_checksum = hasattr(ubinascii, 'crc32')
|
||||||
|
|
||||||
|
if compute_checksum:
|
||||||
|
checksum = ubinascii.crc32(msg_data)
|
||||||
|
else:
|
||||||
|
checksum = 0
|
||||||
|
|
||||||
|
msg_footer = bytearray(_MSG_FOOTER_LEN)
|
||||||
|
serialize_message_footer(msg_footer, checksum)
|
||||||
|
|
||||||
|
first = True
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# move as much as possible from source to target
|
||||||
|
n = min(len(target_data), len(source_data))
|
||||||
|
target_data[:n] = source_data[:n]
|
||||||
|
source_data = source_data[n:]
|
||||||
|
target_data = target_data[n:]
|
||||||
|
|
||||||
|
# continue with the footer if source is empty and we have space
|
||||||
|
if not source_data and target_data and msg_footer:
|
||||||
|
source_data = msg_footer
|
||||||
|
msg_footer = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
target.send(report)
|
||||||
|
|
||||||
|
if not source_data and not msg_footer:
|
||||||
|
break
|
||||||
|
|
||||||
|
if first:
|
||||||
|
# reset to skip the magic and session ID
|
||||||
|
serialize_report_header(report, REP_MARKER_DATA, session_id)
|
||||||
|
target_data = report[_REP_HEADER_LEN:]
|
||||||
|
first = False
|
||||||
|
|
||||||
|
|
||||||
|
def encode_session_open_message(session_id, target):
|
||||||
|
report = bytearray(_REP_LEN)
|
||||||
|
serialize_report_header(report, REP_MARKER_OPEN, session_id)
|
||||||
|
target.send(report)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_session_close_message(session_id, target):
|
||||||
|
report = bytearray(_REP_LEN)
|
||||||
|
serialize_report_header(report, REP_MARKER_CLOSE, session_id)
|
||||||
|
target.send(report)
|
40
src/trezor/wire/wire_dispatcher.py
Normal file
40
src/trezor/wire/wire_dispatcher.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from trezor import log
|
||||||
|
from .wire_codec import parse_report, REP_MARKER_OPEN, REP_MARKER_CLOSE
|
||||||
|
|
||||||
|
|
||||||
|
def dispatch_reports_by_session(handlers,
|
||||||
|
open_handler,
|
||||||
|
close_handler,
|
||||||
|
fallback_handler):
|
||||||
|
'''
|
||||||
|
Consumes reports adhering to the wire codec and dispatches the report
|
||||||
|
payloads by between the passed handlers.
|
||||||
|
'''
|
||||||
|
|
||||||
|
while True:
|
||||||
|
marker, session_id, report_data = parse_report((yield))
|
||||||
|
|
||||||
|
if marker == REP_MARKER_OPEN:
|
||||||
|
log.debug(__name__, 'request for new session')
|
||||||
|
open_handler.send(session_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif marker == REP_MARKER_CLOSE:
|
||||||
|
log.debug(__name__, 'request for closing session %d', session_id)
|
||||||
|
close_handler.send(session_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif session_id in handlers:
|
||||||
|
log.debug(__name__, 'report on session %d', session_id)
|
||||||
|
handler = handlers[session_id]
|
||||||
|
|
||||||
|
else:
|
||||||
|
log.debug(__name__, 'report on unknown session %d', session_id)
|
||||||
|
handler = fallback_handler
|
||||||
|
|
||||||
|
try:
|
||||||
|
handler.send(report_data)
|
||||||
|
except StopIteration:
|
||||||
|
handlers.pop(session_id)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(__name__, e)
|
16
src/trezor/wire/wire_io.py
Normal file
16
src/trezor/wire/wire_io.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from trezor import msg
|
||||||
|
from trezor import loop
|
||||||
|
|
||||||
|
_DEFAULT_IFACE = const(0)
|
||||||
|
|
||||||
|
|
||||||
|
def read_report_stream(target, iface=_DEFAULT_IFACE):
|
||||||
|
while True:
|
||||||
|
report, = yield loop.Select(iface)
|
||||||
|
target.send(report)
|
||||||
|
|
||||||
|
|
||||||
|
def write_report_stream(iface=_DEFAULT_IFACE):
|
||||||
|
while True:
|
||||||
|
report = yield
|
||||||
|
msg.send(iface, report)
|
@ -1,166 +0,0 @@
|
|||||||
import ustruct
|
|
||||||
import ubinascii
|
|
||||||
|
|
||||||
from . import msg
|
|
||||||
from . import loop
|
|
||||||
from .crypto import random
|
|
||||||
|
|
||||||
|
|
||||||
MESSAGE_IFACE = const(0)
|
|
||||||
EMPTY_SESSION = const(0)
|
|
||||||
|
|
||||||
sessions = {}
|
|
||||||
|
|
||||||
|
|
||||||
def generate_session_id():
|
|
||||||
return random.uniform(0xffffffff) + 1
|
|
||||||
|
|
||||||
|
|
||||||
async def dispatch_reports():
|
|
||||||
while True:
|
|
||||||
report = await _read_report()
|
|
||||||
session_id, report_data = _parse_report(report)
|
|
||||||
sessions[session_id].send(report_data)
|
|
||||||
|
|
||||||
|
|
||||||
async def read_session_message(session_id, types):
|
|
||||||
future = loop.Future()
|
|
||||||
pbuf_decoder = _decode_protobuf_message(types, future)
|
|
||||||
wire_decoder = _decode_wire_message(pbuf_decoder)
|
|
||||||
assert session_id not in sessions
|
|
||||||
sessions[session_id] = wire_decoder
|
|
||||||
try:
|
|
||||||
result = await future
|
|
||||||
finally:
|
|
||||||
del sessions[session_id]
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def lookup_protobuf_type(msg_type, pbuf_types):
|
|
||||||
for pt in pbuf_types:
|
|
||||||
if pt.wire_type == msg_type:
|
|
||||||
return pt
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_protobuf_message(types, future):
|
|
||||||
msg_type, _ = yield
|
|
||||||
pbuf_type = lookup_protobuf_type(msg_type, types)
|
|
||||||
target = build_protobuf_message(pbuf_type, future)
|
|
||||||
yield from pbuf_type.load(AsyncBytearrayReader(), target)
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncBytearrayReader:
|
|
||||||
|
|
||||||
def __init__(self, buf=None, n=None):
|
|
||||||
self.buf = buf if buf is not None else bytearray()
|
|
||||||
self.n = n
|
|
||||||
|
|
||||||
def read(self, n):
|
|
||||||
if self.n is not None:
|
|
||||||
self.n -= n
|
|
||||||
if self.n <= 0:
|
|
||||||
raise EOFError()
|
|
||||||
buf = self.buf
|
|
||||||
while len(buf) < n:
|
|
||||||
buf.extend((yield)) # buffer next data chunk
|
|
||||||
result, buf[:] = buf[:n], buf[n:]
|
|
||||||
return result
|
|
||||||
|
|
||||||
def limit(self, n):
|
|
||||||
return AsyncBytearrayReader(self.buf, n)
|
|
||||||
|
|
||||||
|
|
||||||
async def _read_report():
|
|
||||||
report, = await loop.Select(MESSAGE_IFACE)
|
|
||||||
return memoryview(report) # make slicing cheap
|
|
||||||
|
|
||||||
|
|
||||||
async def _write_report(report):
|
|
||||||
return msg.send(MESSAGE_IFACE, report)
|
|
||||||
|
|
||||||
|
|
||||||
# TREZOR wire protocol v2:
|
|
||||||
#
|
|
||||||
# HID report (64B):
|
|
||||||
# - report magic (1B)
|
|
||||||
# - session (4B, BE)
|
|
||||||
# - payload (59B)
|
|
||||||
#
|
|
||||||
# message:
|
|
||||||
# - streamed as payloads of HID reports:
|
|
||||||
# - message type (4B, BE)
|
|
||||||
# - data length (4B, BE)
|
|
||||||
# - data (var-length)
|
|
||||||
# - data checksum (4B, BE)
|
|
||||||
|
|
||||||
|
|
||||||
REP_HEADER = '>BL' # marker, session id
|
|
||||||
MSG_HEADER = '>LL' # msg type, data length
|
|
||||||
MSG_FOOTER = '>L' # data checksum
|
|
||||||
|
|
||||||
REP_HEADER_LEN = ustruct.calcsize(REP_HEADER)
|
|
||||||
MSG_HEADER_LEN = ustruct.calcsize(MSG_HEADER)
|
|
||||||
MSG_FOOTER_LEN = ustruct.calcsize(MSG_FOOTER)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageChecksumError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_report(data):
|
|
||||||
marker, session_id = ustruct.parse(REP_HEADER, data)
|
|
||||||
return session_id, data[REP_HEADER_LEN:]
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_message(data):
|
|
||||||
msg_type, data_len = ustruct.parse(MSG_HEADER, data)
|
|
||||||
return msg_type, data_len, data[MSG_HEADER_LEN:]
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_footer(data):
|
|
||||||
data_checksum, = ustruct.parse(MSG_FOOTER, data)
|
|
||||||
return data_checksum,
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_wire_message(target):
|
|
||||||
'''Decode a wire message from the report data and stream it to target.
|
|
||||||
|
|
||||||
Receives report payloads.
|
|
||||||
Sends (msg_type, data_len) to target, followed by data chunks.
|
|
||||||
Throws EOFError after last data chunk, in case of valid checksum.
|
|
||||||
Throws MessageChecksumError to target if data doesn't match the checksum.
|
|
||||||
'''
|
|
||||||
message = (yield) # read first report
|
|
||||||
msg_type, data_len, data_tail = _parse_message(message)
|
|
||||||
target.send((msg_type, data_len))
|
|
||||||
|
|
||||||
checksum = 0 # crc32
|
|
||||||
nreports = 1
|
|
||||||
|
|
||||||
while data_len > 0:
|
|
||||||
if nreports > 1:
|
|
||||||
data_tail = (yield) # read next report
|
|
||||||
nreports += 1
|
|
||||||
|
|
||||||
data_chunk = data_tail[:data_len] # slice off the garbage at the end
|
|
||||||
data_tail = data_tail[len(data_chunk):] # slice off what we have read
|
|
||||||
data_len -= len(data_chunk)
|
|
||||||
target.send(data_chunk)
|
|
||||||
|
|
||||||
checksum = ubinascii.crc32(checksum, data_chunk)
|
|
||||||
|
|
||||||
data_footer = data_tail[:MSG_FOOTER_LEN]
|
|
||||||
if len(data_footer) < MSG_FOOTER_LEN:
|
|
||||||
data_tail = (yield) # read report with the rest of checksum
|
|
||||||
data_footer += data_tail[:MSG_FOOTER_LEN - len(data_footer)]
|
|
||||||
|
|
||||||
data_checksum, = _parse_footer(data_footer)
|
|
||||||
if data_checksum != checksum:
|
|
||||||
target.throw(MessageChecksumError, 'Message checksum mismatch')
|
|
||||||
else:
|
|
||||||
target.throw(EOFError)
|
|
||||||
|
|
||||||
|
|
||||||
def _encode_message(target):
|
|
||||||
pass
|
|
37
src/trezor/workflow.py
Normal file
37
src/trezor/workflow.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from trezor import log, loop
|
||||||
|
|
||||||
|
_started_workflows = []
|
||||||
|
_default_workflow = None
|
||||||
|
_default_workflow_genfunc = None
|
||||||
|
|
||||||
|
|
||||||
|
def start_default(genfunc):
|
||||||
|
global _default_workflow
|
||||||
|
global _default_workflow_genfunc
|
||||||
|
_default_workflow_genfunc = genfunc
|
||||||
|
_default_workflow = _default_workflow_genfunc()
|
||||||
|
log.info(__name__, 'starting default workflow %s', _default_workflow)
|
||||||
|
loop.schedule_task(_default_workflow)
|
||||||
|
|
||||||
|
|
||||||
|
def start_workflow(workflow):
|
||||||
|
global _default_workflow
|
||||||
|
if _default_workflow is not None:
|
||||||
|
log.info(__name__, 'closing default workflow %s', _default_workflow)
|
||||||
|
_default_workflow.close()
|
||||||
|
_default_workflow = None
|
||||||
|
|
||||||
|
log.info(__name__, 'starting workflow %s', workflow)
|
||||||
|
_started_workflows.append(workflow)
|
||||||
|
loop.schedule_task(watch_workflow(workflow))
|
||||||
|
|
||||||
|
|
||||||
|
async def watch_workflow(workflow):
|
||||||
|
global _default_workflow
|
||||||
|
try:
|
||||||
|
return await workflow
|
||||||
|
finally:
|
||||||
|
_started_workflows.remove(workflow)
|
||||||
|
|
||||||
|
if not _started_workflows and _default_workflow_genfunc is not None:
|
||||||
|
start_default(_default_workflow_genfunc)
|
Loading…
Reference in New Issue
Block a user