workflow monitor, reset_device layout work

- request/response manner of usual protobuf workflows is enforced, workflows are expected to either return a valid protobuf response, or raise an exception
- added wire.FailureError exception that allows workflow to provide Failure code & message
- pin workflows simplified

TODO: all this workflow work does not really belong in trezor.wire
pull/25/head
Jan Pochyla 8 years ago committed by Pavol Rusnak
parent 70110187cc
commit 34ed2fb86a
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D

@ -1,4 +1,5 @@
from trezor import wire, ui
from trezor.workflows.request_pin import request_new_pin
from trezor.messages.wire_types import EntropyAck
from trezor.ui.button import Button, CONFIRM_BUTTON, CONFIRM_BUTTON_ACTIVE
from trezor.ui.scroll import paginate, render_scrollbar, animate_swipe
@ -6,75 +7,72 @@ from trezor.crypto import hashlib, random, bip39
from trezor.utils import unimport, chunks
@unimport
async def generate_mnemonic(strength, display_random, session_id):
from trezor.messages.EntropyRequest import EntropyRequest
from trezor.messages.FailureType import Other
await wire.write_message(session_id, EntropyRequest())
ack = await wire.read_message(session_id, EntropyAck)
if strength not in (128, 192, 256):
raise wire.FailureError(Other, 'Invalid seed strength')
# if display_random:
# raise wire.FailureError(Other, 'Entropy display not implemented')
ack = await wire.reply_message(session_id,
EntropyRequest(),
EntropyAck)
strength_bytes = strength // 8
ctx = hashlib.sha256()
ctx.update(random.bytes(32))
ctx.update(ack.entropy)
ctx.update(random.bytes(strength_bytes))
ctx.update(ack.entropy[:strength_bytes])
entropy = ctx.digest()
# TODO: handle strength
# TODO: handle display_random
return bip39.from_data(entropy)
async def request_new_pin():
from trezor.workflows.request_pin import request_pin
pin = await request_pin()
pin_again = await request_pin('Enter PIN again')
if pin == pin_again:
return pin
async def show_mnemonic_page(page, page_count, mnemonic):
ui.clear()
ui.display.text(10, 30, 'Write down your seed',
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
render_scrollbar(page, page_count)
for pi, (wi, word) in enumerate(mnemonic[page]):
top = pi * 30 + 74
pos = wi + 1
ui.display.text_right(40, top, '%d.' % pos,
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
ui.display.text(45, top, '%s' % word,
ui.BOLD, ui.WHITE, ui.BLACK)
if page + 1 == page_count:
await Button((0, 240 - 48, 240, 48), 'Finish',
normal_style=CONFIRM_BUTTON,
active_style=CONFIRM_BUTTON_ACTIVE)
else:
raise Exception() # TODO: wrong PIN should be handled in unified way
await animate_swipe()
async def show_mnemonic(mnemonic):
first_page = const(0)
words_per_page = const(4)
mnemonic_words = list(enumerate(mnemonic.split()))
mnemonic_pages = list(chunks(mnemonic_words, words_per_page))
async def render(page, page_count):
# render header & scrollbar
ui.clear()
ui.display.text(10, 30, 'Write down your seed',
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
render_scrollbar(page, page_count)
# render mnemonic page
for pi, (wi, word) in enumerate(mnemonic_pages[page]):
top = pi * 30 + 74
pos = wi + 1
ui.display.text_right(40, top, '%d.' %
pos, ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
ui.display.text(45, top, '%s' % word, ui.BOLD, ui.WHITE, ui.BLACK)
if page + 1 == page_count:
await Button((0, 240 - 48, 240, 48), 'Finish',
normal_style=CONFIRM_BUTTON,
active_style=CONFIRM_BUTTON_ACTIVE)
else:
await animate_swipe()
await paginate(render, len(mnemonic_pages))
words = list(enumerate(mnemonic.split()))
pages = list(chunks(words, words_per_page))
await paginate(show_mnemonic_page, len(pages), first_page, pages)
@unimport
async def layout_reset_device(message, session_id):
# TODO: Failure if not empty
from trezor.messages.Success import Success
mnemonic = await generate_mnemonic(
message.strength, message.display_random, session_id)
mnemonic = await generate_mnemonic(message.strength,
message.display_random,
session_id)
# await show_mnemonic(mnemonic)
# if m.pin_protection:
# pin = yield from request_new_pin()
# else:
# pin = None
if message.pin_protection:
pin = await request_new_pin(session_id)
else:
pin = None
await show_mnemonic(mnemonic)
return Success()

@ -11,10 +11,10 @@ async def change_page(page, page_count):
return page - 1 # scroll up
async def paginate(render_page, page_count, page=0):
async def paginate(render_page, page_count, page=0, *args):
while True:
changer = change_page(page, page_count)
renderer = render_page(page, page_count)
renderer = render_page(page, page_count, *args)
waiter = loop.Wait([changer, renderer])
result = await waiter
if changer in waiter.finished:
@ -41,7 +41,7 @@ def render_scrollbar(page, page_count):
if page_count * padding > screen_height:
padding = screen_height // page_count
x = 225
x = const(225)
y = (screen_height // 2) - (page_count // 2) * padding
for i in range(0, page_count):

@ -40,8 +40,7 @@ def close_session(session_id):
def register_type(wire_type, genfunc, *args):
if wire_type in _workflow_genfuncs:
raise KeyError('message of type %d already registered' % wire_type)
log.info(__name__, 'registering %s for type %d',
(genfunc, args), wire_type)
log.info(__name__, 'registering message type %d', wire_type)
_workflow_genfuncs[wire_type] = (genfunc, args)
@ -49,8 +48,8 @@ def register_session(session_id, handler):
if session_id not in _opened_sessions:
raise KeyError('session %d is unknown' % session_id)
if session_id in _session_handlers:
raise KeyError('session %d is already registered' % session_id)
log.info(__name__, 'registering %s for session %d', handler, session_id)
raise KeyError('session %d is already being listened on' % session_id)
log.info(__name__, 'listening on session %d', session_id)
_session_handlers[session_id] = handler
@ -78,6 +77,7 @@ def setup():
async def read_message(session_id, *exp_types):
log.info(__name__, 'reading message, one of %s', exp_types)
future = Future()
wire_decoder = decode_wire_stream(
_dispatch_and_build_protobuf, session_id, exp_types, future)
@ -87,6 +87,7 @@ async def read_message(session_id, *exp_types):
async def write_message(session_id, pbuf_message):
log.info(__name__, 'writing message %s', pbuf_message)
msg_data = await pbuf_message.dumps()
msg_type = pbuf_message.message_type.wire_type
writer = write_report_stream()
@ -94,9 +95,56 @@ async def write_message(session_id, pbuf_message):
encode_wire_message(msg_type, msg_data, session_id, writer)
async def reply_message(session_id, pbuf_message, *exp_types):
await write_message(session_id, pbuf_message)
return await read_message(session_id, *exp_types)
class FailureError(Exception):
def __init__(self, code, message):
super(FailureError, self).__init__(code, message)
def to_protobuf(self):
from trezor.messages.Failure import Failure
return Failure(code=self.args[0],
message=self.args[1])
async def monitor_workflow(workflow, session_id):
try:
result = await workflow
except FailureError as e:
await write_message(session_id, e.to_protobuf())
raise
except Exception as e:
from trezor.messages.Failure import Failure
from trezor.messages.FailureType import FirmwareError
await write_message(session_id,
Failure(code=FirmwareError,
message='Firmware Error'))
raise
else:
if result is not None:
await write_message(session_id, result)
return result
finally:
if session_id in _opened_sessions:
wire_decoder = decode_wire_stream(
_handle_registered_type, session_id)
wire_decoder.send(None)
register_session(session_id, wire_decoder)
def protobuf_handler(msg_type, data_len, session_id, callback, *args):
def finalizer(message):
start_workflow(callback(message, session_id, *args))
workflow = callback(message, session_id, *args)
monitored = monitor_workflow(workflow, session_id)
start_workflow(monitored)
pbuf_type = get_protobuf_type(msg_type)
builder = build_protobuf_message(pbuf_type, finalizer)
builder.send(None)
@ -125,10 +173,6 @@ def _handle_unknown_session():
yield # TODO
class UnexpectedMessageError(Exception):
pass
def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, future):
if msg_type in exp_types:
pbuf_type = get_protobuf_type(msg_type)
@ -136,18 +180,20 @@ def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, futu
builder.send(None)
return pbuf_type.load(builder)
else:
future.resolve(UnexpectedMessageError(msg_type))
from trezor.messages.FailureType import UnexpectedMessage
future.resolve(FailureError(UnexpectedMessage, 'Unexpected message'))
return _handle_registered_type(msg_type, data_len, session_id)
def _handle_registered_type(msg_type, data_len, session_id):
genfunc, args = _workflow_genfuncs.get(
msg_type, (_handle_unexpected_type, ()))
fallback = (_handle_unexpected_type, ())
genfunc, args = _workflow_genfuncs.get(msg_type, fallback)
return genfunc(msg_type, data_len, session_id, *args)
def _handle_unexpected_type(msg_type, data_len, session_id):
log.info(__name__, 'skipping message %d of len %d' % (msg_type, data_len))
log.info(__name__, 'skipping message %d of len %d on session %d' %
(msg_type, data_len, session_id))
try:
while True:
yield

@ -1,58 +1,37 @@
from trezor import ui
from trezor import wire
from trezor import config
from trezor.utils import unimport
MANAGEMENT_APP = const(1)
PASSPHRASE_PROTECT = (1) # 0 | 1
PIN_PROTECT = const(2) # 0 | 1
PIN = const(4) # str
def prompt_pin(*args, **kwargs):
from trezor.ui.pin import PinMatrix
@unimport
async def request_pin(session_id, *args, **kwargs):
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.ButtonRequestType import ProtectCall
from trezor.messages.FailureType import PinCancelled
from trezor.messages.wire_types import ButtonAck
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
from trezor.ui.pin import PinMatrix
ui.clear()
await wire.reply_message(session_id,
ButtonRequest(code=ProtectCall),
ButtonAck)
ui.clear()
matrix = PinMatrix(*args, **kwargs)
dialog = ConfirmDialog(matrix)
result = yield from dialog.wait()
if await dialog != CONFIRMED:
raise wire.FailureError(PinCancelled, 'PIN cancelled')
return matrix.pin if result == CONFIRMED else None
return matrix.pin
def request_pin(*args, **kwargs):
from trezor.messages.ButtonRequest import ButtonRequest
from trezor.messages.ButtonRequestType import ProtectCall
from trezor.messages.ButtonAck import ButtonAck
ack = yield from wire.call(ButtonRequest(code=ProtectCall), ButtonAck)
pin = yield from prompt_pin(*args, **kwargs)
return pin
def change_pin():
pass
def protect_with_pin():
from trezor.messages.Failure import Failure
@unimport
async def request_new_pin(session_id):
from trezor.messages.FailureType import PinInvalid
from trezor.messages.FailureType import ActionCancelled
pin_protect = config.get(MANAGEMENT_APP, PIN_PROTECT)
if not pin_protect:
return
entered_pin = yield from request_pin()
if entered_pin is None:
yield from wire.write(Failure(code=ActionCancelled, message='Cancelled'))
raise Exception('Cancelled')
pin_first = await request_pin(session_id)
pin_again = await request_pin(session_id, 'Enter PIN again')
if pin_first != pin_again:
raise wire.FailureError(PinInvalid, 'PIN invalid')
stored_pin = config.get(MANAGEMENT_APP, PIN)
if stored_pin != entered_pin:
yield from wire.write(Failure(code=PinInvalid, message='PIN invalid'))
raise Exception('PIN invalid')
return pin_first

Loading…
Cancel
Save