mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-04 03:40:58 +00:00
workflow monitor, reset_device layout work
- request/response manner of usual protobuf workflows is enforced, workflows are expected to either return a valid protobuf response, or raise an exception - added wire.FailureError exception that allows workflow to provide Failure code & message - pin workflows simplified TODO: all this workflow work does not really belong in trezor.wire
This commit is contained in:
parent
70110187cc
commit
34ed2fb86a
@ -1,4 +1,5 @@
|
|||||||
from trezor import wire, ui
|
from trezor import wire, ui
|
||||||
|
from trezor.workflows.request_pin import request_new_pin
|
||||||
from trezor.messages.wire_types import EntropyAck
|
from trezor.messages.wire_types import EntropyAck
|
||||||
from trezor.ui.button import Button, CONFIRM_BUTTON, CONFIRM_BUTTON_ACTIVE
|
from trezor.ui.button import Button, CONFIRM_BUTTON, CONFIRM_BUTTON_ACTIVE
|
||||||
from trezor.ui.scroll import paginate, render_scrollbar, animate_swipe
|
from trezor.ui.scroll import paginate, render_scrollbar, animate_swipe
|
||||||
@ -6,75 +7,72 @@ from trezor.crypto import hashlib, random, bip39
|
|||||||
from trezor.utils import unimport, chunks
|
from trezor.utils import unimport, chunks
|
||||||
|
|
||||||
|
|
||||||
|
@unimport
|
||||||
async def generate_mnemonic(strength, display_random, session_id):
|
async def generate_mnemonic(strength, display_random, session_id):
|
||||||
from trezor.messages.EntropyRequest import EntropyRequest
|
from trezor.messages.EntropyRequest import EntropyRequest
|
||||||
|
from trezor.messages.FailureType import Other
|
||||||
|
|
||||||
await wire.write_message(session_id, EntropyRequest())
|
if strength not in (128, 192, 256):
|
||||||
ack = await wire.read_message(session_id, EntropyAck)
|
raise wire.FailureError(Other, 'Invalid seed strength')
|
||||||
|
|
||||||
|
# if display_random:
|
||||||
|
# raise wire.FailureError(Other, 'Entropy display not implemented')
|
||||||
|
|
||||||
|
ack = await wire.reply_message(session_id,
|
||||||
|
EntropyRequest(),
|
||||||
|
EntropyAck)
|
||||||
|
|
||||||
|
strength_bytes = strength // 8
|
||||||
ctx = hashlib.sha256()
|
ctx = hashlib.sha256()
|
||||||
ctx.update(random.bytes(32))
|
ctx.update(random.bytes(strength_bytes))
|
||||||
ctx.update(ack.entropy)
|
ctx.update(ack.entropy[:strength_bytes])
|
||||||
entropy = ctx.digest()
|
entropy = ctx.digest()
|
||||||
|
|
||||||
# TODO: handle strength
|
|
||||||
# TODO: handle display_random
|
|
||||||
|
|
||||||
return bip39.from_data(entropy)
|
return bip39.from_data(entropy)
|
||||||
|
|
||||||
|
|
||||||
async def request_new_pin():
|
async def show_mnemonic_page(page, page_count, mnemonic):
|
||||||
from trezor.workflows.request_pin import request_pin
|
ui.clear()
|
||||||
|
ui.display.text(10, 30, 'Write down your seed',
|
||||||
|
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
|
||||||
|
render_scrollbar(page, page_count)
|
||||||
|
|
||||||
pin = await request_pin()
|
for pi, (wi, word) in enumerate(mnemonic[page]):
|
||||||
pin_again = await request_pin('Enter PIN again')
|
top = pi * 30 + 74
|
||||||
|
pos = wi + 1
|
||||||
|
ui.display.text_right(40, top, '%d.' % pos,
|
||||||
|
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
|
||||||
|
ui.display.text(45, top, '%s' % word,
|
||||||
|
ui.BOLD, ui.WHITE, ui.BLACK)
|
||||||
|
|
||||||
if pin == pin_again:
|
if page + 1 == page_count:
|
||||||
return pin
|
await Button((0, 240 - 48, 240, 48), 'Finish',
|
||||||
|
normal_style=CONFIRM_BUTTON,
|
||||||
|
active_style=CONFIRM_BUTTON_ACTIVE)
|
||||||
else:
|
else:
|
||||||
raise Exception() # TODO: wrong PIN should be handled in unified way
|
await animate_swipe()
|
||||||
|
|
||||||
|
|
||||||
async def show_mnemonic(mnemonic):
|
async def show_mnemonic(mnemonic):
|
||||||
|
first_page = const(0)
|
||||||
words_per_page = const(4)
|
words_per_page = const(4)
|
||||||
mnemonic_words = list(enumerate(mnemonic.split()))
|
words = list(enumerate(mnemonic.split()))
|
||||||
mnemonic_pages = list(chunks(mnemonic_words, words_per_page))
|
pages = list(chunks(words, words_per_page))
|
||||||
|
await paginate(show_mnemonic_page, len(pages), first_page, pages)
|
||||||
async def render(page, page_count):
|
|
||||||
|
|
||||||
# render header & scrollbar
|
|
||||||
ui.clear()
|
|
||||||
ui.display.text(10, 30, 'Write down your seed',
|
|
||||||
ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
|
|
||||||
render_scrollbar(page, page_count)
|
|
||||||
|
|
||||||
# render mnemonic page
|
|
||||||
for pi, (wi, word) in enumerate(mnemonic_pages[page]):
|
|
||||||
top = pi * 30 + 74
|
|
||||||
pos = wi + 1
|
|
||||||
ui.display.text_right(40, top, '%d.' %
|
|
||||||
pos, ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)
|
|
||||||
ui.display.text(45, top, '%s' % word, ui.BOLD, ui.WHITE, ui.BLACK)
|
|
||||||
|
|
||||||
if page + 1 == page_count:
|
|
||||||
await Button((0, 240 - 48, 240, 48), 'Finish',
|
|
||||||
normal_style=CONFIRM_BUTTON,
|
|
||||||
active_style=CONFIRM_BUTTON_ACTIVE)
|
|
||||||
else:
|
|
||||||
await animate_swipe()
|
|
||||||
|
|
||||||
await paginate(render, len(mnemonic_pages))
|
|
||||||
|
|
||||||
|
|
||||||
|
@unimport
|
||||||
async def layout_reset_device(message, session_id):
|
async def layout_reset_device(message, session_id):
|
||||||
# TODO: Failure if not empty
|
from trezor.messages.Success import Success
|
||||||
|
|
||||||
mnemonic = await generate_mnemonic(
|
mnemonic = await generate_mnemonic(message.strength,
|
||||||
message.strength, message.display_random, session_id)
|
message.display_random,
|
||||||
|
session_id)
|
||||||
|
# await show_mnemonic(mnemonic)
|
||||||
|
|
||||||
# if m.pin_protection:
|
if message.pin_protection:
|
||||||
# pin = yield from request_new_pin()
|
pin = await request_new_pin(session_id)
|
||||||
# else:
|
else:
|
||||||
# pin = None
|
pin = None
|
||||||
|
|
||||||
await show_mnemonic(mnemonic)
|
return Success()
|
||||||
|
@ -11,10 +11,10 @@ async def change_page(page, page_count):
|
|||||||
return page - 1 # scroll up
|
return page - 1 # scroll up
|
||||||
|
|
||||||
|
|
||||||
async def paginate(render_page, page_count, page=0):
|
async def paginate(render_page, page_count, page=0, *args):
|
||||||
while True:
|
while True:
|
||||||
changer = change_page(page, page_count)
|
changer = change_page(page, page_count)
|
||||||
renderer = render_page(page, page_count)
|
renderer = render_page(page, page_count, *args)
|
||||||
waiter = loop.Wait([changer, renderer])
|
waiter = loop.Wait([changer, renderer])
|
||||||
result = await waiter
|
result = await waiter
|
||||||
if changer in waiter.finished:
|
if changer in waiter.finished:
|
||||||
@ -41,7 +41,7 @@ def render_scrollbar(page, page_count):
|
|||||||
if page_count * padding > screen_height:
|
if page_count * padding > screen_height:
|
||||||
padding = screen_height // page_count
|
padding = screen_height // page_count
|
||||||
|
|
||||||
x = 225
|
x = const(225)
|
||||||
y = (screen_height // 2) - (page_count // 2) * padding
|
y = (screen_height // 2) - (page_count // 2) * padding
|
||||||
|
|
||||||
for i in range(0, page_count):
|
for i in range(0, page_count):
|
||||||
|
@ -40,8 +40,7 @@ def close_session(session_id):
|
|||||||
def register_type(wire_type, genfunc, *args):
|
def register_type(wire_type, genfunc, *args):
|
||||||
if wire_type in _workflow_genfuncs:
|
if wire_type in _workflow_genfuncs:
|
||||||
raise KeyError('message of type %d already registered' % wire_type)
|
raise KeyError('message of type %d already registered' % wire_type)
|
||||||
log.info(__name__, 'registering %s for type %d',
|
log.info(__name__, 'registering message type %d', wire_type)
|
||||||
(genfunc, args), wire_type)
|
|
||||||
_workflow_genfuncs[wire_type] = (genfunc, args)
|
_workflow_genfuncs[wire_type] = (genfunc, args)
|
||||||
|
|
||||||
|
|
||||||
@ -49,8 +48,8 @@ def register_session(session_id, handler):
|
|||||||
if session_id not in _opened_sessions:
|
if session_id not in _opened_sessions:
|
||||||
raise KeyError('session %d is unknown' % session_id)
|
raise KeyError('session %d is unknown' % session_id)
|
||||||
if session_id in _session_handlers:
|
if session_id in _session_handlers:
|
||||||
raise KeyError('session %d is already registered' % session_id)
|
raise KeyError('session %d is already being listened on' % session_id)
|
||||||
log.info(__name__, 'registering %s for session %d', handler, session_id)
|
log.info(__name__, 'listening on session %d', session_id)
|
||||||
_session_handlers[session_id] = handler
|
_session_handlers[session_id] = handler
|
||||||
|
|
||||||
|
|
||||||
@ -78,6 +77,7 @@ def setup():
|
|||||||
|
|
||||||
|
|
||||||
async def read_message(session_id, *exp_types):
|
async def read_message(session_id, *exp_types):
|
||||||
|
log.info(__name__, 'reading message, one of %s', exp_types)
|
||||||
future = Future()
|
future = Future()
|
||||||
wire_decoder = decode_wire_stream(
|
wire_decoder = decode_wire_stream(
|
||||||
_dispatch_and_build_protobuf, session_id, exp_types, future)
|
_dispatch_and_build_protobuf, session_id, exp_types, future)
|
||||||
@ -87,6 +87,7 @@ async def read_message(session_id, *exp_types):
|
|||||||
|
|
||||||
|
|
||||||
async def write_message(session_id, pbuf_message):
|
async def write_message(session_id, pbuf_message):
|
||||||
|
log.info(__name__, 'writing message %s', pbuf_message)
|
||||||
msg_data = await pbuf_message.dumps()
|
msg_data = await pbuf_message.dumps()
|
||||||
msg_type = pbuf_message.message_type.wire_type
|
msg_type = pbuf_message.message_type.wire_type
|
||||||
writer = write_report_stream()
|
writer = write_report_stream()
|
||||||
@ -94,9 +95,56 @@ async def write_message(session_id, pbuf_message):
|
|||||||
encode_wire_message(msg_type, msg_data, session_id, writer)
|
encode_wire_message(msg_type, msg_data, session_id, writer)
|
||||||
|
|
||||||
|
|
||||||
|
async def reply_message(session_id, pbuf_message, *exp_types):
|
||||||
|
await write_message(session_id, pbuf_message)
|
||||||
|
return await read_message(session_id, *exp_types)
|
||||||
|
|
||||||
|
|
||||||
|
class FailureError(Exception):
|
||||||
|
|
||||||
|
def __init__(self, code, message):
|
||||||
|
super(FailureError, self).__init__(code, message)
|
||||||
|
|
||||||
|
def to_protobuf(self):
|
||||||
|
from trezor.messages.Failure import Failure
|
||||||
|
return Failure(code=self.args[0],
|
||||||
|
message=self.args[1])
|
||||||
|
|
||||||
|
|
||||||
|
async def monitor_workflow(workflow, session_id):
|
||||||
|
try:
|
||||||
|
result = await workflow
|
||||||
|
|
||||||
|
except FailureError as e:
|
||||||
|
await write_message(session_id, e.to_protobuf())
|
||||||
|
raise
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
from trezor.messages.Failure import Failure
|
||||||
|
from trezor.messages.FailureType import FirmwareError
|
||||||
|
await write_message(session_id,
|
||||||
|
Failure(code=FirmwareError,
|
||||||
|
message='Firmware Error'))
|
||||||
|
raise
|
||||||
|
|
||||||
|
else:
|
||||||
|
if result is not None:
|
||||||
|
await write_message(session_id, result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if session_id in _opened_sessions:
|
||||||
|
wire_decoder = decode_wire_stream(
|
||||||
|
_handle_registered_type, session_id)
|
||||||
|
wire_decoder.send(None)
|
||||||
|
register_session(session_id, wire_decoder)
|
||||||
|
|
||||||
|
|
||||||
def protobuf_handler(msg_type, data_len, session_id, callback, *args):
|
def protobuf_handler(msg_type, data_len, session_id, callback, *args):
|
||||||
def finalizer(message):
|
def finalizer(message):
|
||||||
start_workflow(callback(message, session_id, *args))
|
workflow = callback(message, session_id, *args)
|
||||||
|
monitored = monitor_workflow(workflow, session_id)
|
||||||
|
start_workflow(monitored)
|
||||||
pbuf_type = get_protobuf_type(msg_type)
|
pbuf_type = get_protobuf_type(msg_type)
|
||||||
builder = build_protobuf_message(pbuf_type, finalizer)
|
builder = build_protobuf_message(pbuf_type, finalizer)
|
||||||
builder.send(None)
|
builder.send(None)
|
||||||
@ -125,10 +173,6 @@ def _handle_unknown_session():
|
|||||||
yield # TODO
|
yield # TODO
|
||||||
|
|
||||||
|
|
||||||
class UnexpectedMessageError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, future):
|
def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, future):
|
||||||
if msg_type in exp_types:
|
if msg_type in exp_types:
|
||||||
pbuf_type = get_protobuf_type(msg_type)
|
pbuf_type = get_protobuf_type(msg_type)
|
||||||
@ -136,18 +180,20 @@ def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, futu
|
|||||||
builder.send(None)
|
builder.send(None)
|
||||||
return pbuf_type.load(builder)
|
return pbuf_type.load(builder)
|
||||||
else:
|
else:
|
||||||
future.resolve(UnexpectedMessageError(msg_type))
|
from trezor.messages.FailureType import UnexpectedMessage
|
||||||
|
future.resolve(FailureError(UnexpectedMessage, 'Unexpected message'))
|
||||||
return _handle_registered_type(msg_type, data_len, session_id)
|
return _handle_registered_type(msg_type, data_len, session_id)
|
||||||
|
|
||||||
|
|
||||||
def _handle_registered_type(msg_type, data_len, session_id):
|
def _handle_registered_type(msg_type, data_len, session_id):
|
||||||
genfunc, args = _workflow_genfuncs.get(
|
fallback = (_handle_unexpected_type, ())
|
||||||
msg_type, (_handle_unexpected_type, ()))
|
genfunc, args = _workflow_genfuncs.get(msg_type, fallback)
|
||||||
return genfunc(msg_type, data_len, session_id, *args)
|
return genfunc(msg_type, data_len, session_id, *args)
|
||||||
|
|
||||||
|
|
||||||
def _handle_unexpected_type(msg_type, data_len, session_id):
|
def _handle_unexpected_type(msg_type, data_len, session_id):
|
||||||
log.info(__name__, 'skipping message %d of len %d' % (msg_type, data_len))
|
log.info(__name__, 'skipping message %d of len %d on session %d' %
|
||||||
|
(msg_type, data_len, session_id))
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
yield
|
yield
|
||||||
|
@ -1,58 +1,37 @@
|
|||||||
from trezor import ui
|
from trezor import ui
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor import config
|
|
||||||
from trezor.utils import unimport
|
from trezor.utils import unimport
|
||||||
|
|
||||||
MANAGEMENT_APP = const(1)
|
|
||||||
|
|
||||||
PASSPHRASE_PROTECT = (1) # 0 | 1
|
@unimport
|
||||||
PIN_PROTECT = const(2) # 0 | 1
|
async def request_pin(session_id, *args, **kwargs):
|
||||||
PIN = const(4) # str
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_pin(*args, **kwargs):
|
|
||||||
from trezor.ui.pin import PinMatrix
|
|
||||||
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
|
|
||||||
|
|
||||||
ui.clear()
|
|
||||||
|
|
||||||
matrix = PinMatrix(*args, **kwargs)
|
|
||||||
dialog = ConfirmDialog(matrix)
|
|
||||||
result = yield from dialog.wait()
|
|
||||||
|
|
||||||
return matrix.pin if result == CONFIRMED else None
|
|
||||||
|
|
||||||
|
|
||||||
def request_pin(*args, **kwargs):
|
|
||||||
from trezor.messages.ButtonRequest import ButtonRequest
|
from trezor.messages.ButtonRequest import ButtonRequest
|
||||||
from trezor.messages.ButtonRequestType import ProtectCall
|
from trezor.messages.ButtonRequestType import ProtectCall
|
||||||
from trezor.messages.ButtonAck import ButtonAck
|
from trezor.messages.FailureType import PinCancelled
|
||||||
|
from trezor.messages.wire_types import ButtonAck
|
||||||
|
from trezor.ui.confirm import ConfirmDialog, CONFIRMED
|
||||||
|
from trezor.ui.pin import PinMatrix
|
||||||
|
|
||||||
ack = yield from wire.call(ButtonRequest(code=ProtectCall), ButtonAck)
|
await wire.reply_message(session_id,
|
||||||
pin = yield from prompt_pin(*args, **kwargs)
|
ButtonRequest(code=ProtectCall),
|
||||||
|
ButtonAck)
|
||||||
|
|
||||||
return pin
|
ui.clear()
|
||||||
|
matrix = PinMatrix(*args, **kwargs)
|
||||||
|
dialog = ConfirmDialog(matrix)
|
||||||
|
if await dialog != CONFIRMED:
|
||||||
|
raise wire.FailureError(PinCancelled, 'PIN cancelled')
|
||||||
|
|
||||||
|
return matrix.pin
|
||||||
|
|
||||||
|
|
||||||
def change_pin():
|
@unimport
|
||||||
pass
|
async def request_new_pin(session_id):
|
||||||
|
|
||||||
|
|
||||||
def protect_with_pin():
|
|
||||||
from trezor.messages.Failure import Failure
|
|
||||||
from trezor.messages.FailureType import PinInvalid
|
from trezor.messages.FailureType import PinInvalid
|
||||||
from trezor.messages.FailureType import ActionCancelled
|
|
||||||
|
|
||||||
pin_protect = config.get(MANAGEMENT_APP, PIN_PROTECT)
|
pin_first = await request_pin(session_id)
|
||||||
if not pin_protect:
|
pin_again = await request_pin(session_id, 'Enter PIN again')
|
||||||
return
|
if pin_first != pin_again:
|
||||||
|
raise wire.FailureError(PinInvalid, 'PIN invalid')
|
||||||
|
|
||||||
entered_pin = yield from request_pin()
|
return pin_first
|
||||||
if entered_pin is None:
|
|
||||||
yield from wire.write(Failure(code=ActionCancelled, message='Cancelled'))
|
|
||||||
raise Exception('Cancelled')
|
|
||||||
|
|
||||||
stored_pin = config.get(MANAGEMENT_APP, PIN)
|
|
||||||
if stored_pin != entered_pin:
|
|
||||||
yield from wire.write(Failure(code=PinInvalid, message='PIN invalid'))
|
|
||||||
raise Exception('PIN invalid')
|
|
||||||
|
Loading…
Reference in New Issue
Block a user