1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-07 05:51:38 +00:00

apps/homescreen: handle Initialize.skip_passphrase

TODO: tests
This commit is contained in:
Jan Pochyla 2018-05-28 15:20:31 +02:00
parent df17be0287
commit cabb334448
3 changed files with 78 additions and 64 deletions

View File

@ -1,46 +1,57 @@
from trezor.crypto import hashlib, hmac, random from trezor.crypto import hashlib, hmac, random
from apps.common import storage from apps.common import storage
memory = {} _cached_seed = None
_seed = None _cached_passphrase = None
_passphrase = None
def get_state(state: bytes=None, passphrase: str=None): def get_state(prev_state: bytes = None, passphrase: str = None) -> bytes:
if prev_state is None:
if state is None:
salt = random.bytes(32) # generate a random salt if no state provided salt = random.bytes(32) # generate a random salt if no state provided
else: else:
salt = state[:32] # use salt from provided state salt = prev_state[:32] # use salt from provided state
if len(salt) != 32:
return None # invalid state
if passphrase is None: if passphrase is None:
global _passphrase if _cached_passphrase is None:
if _passphrase is None: return None # we don't have any passphrase to compute the state
return None else:
passphrase = _passphrase # use cached passphrase passphrase = _cached_passphrase # use cached passphrase
return _compute_state(salt, passphrase)
def _compute_state(salt: bytes, passphrase: str) -> bytes:
# state = HMAC(passphrase, salt || device_id) # state = HMAC(passphrase, salt || device_id)
msg = salt + storage.get_device_id().encode() message = salt + storage.get_device_id().encode()
state = hmac.new(passphrase.encode(), msg, hashlib.sha256).digest() state = hmac.new(passphrase.encode(), message, hashlib.sha256).digest()
return salt + state return salt + state
def get_seed(): def get_seed():
global _seed return _cached_seed
return _seed
def set_seed(seed, passphrase): def get_passphrase():
global _seed, _passphrase return _cached_passphrase
_seed, _passphrase = seed, passphrase
def has_passphrase(): def has_passphrase():
global _passphrase return _cached_passphrase is not None
return _passphrase is not None
def clear(): def set_seed(seed):
global _seed, _passphrase global _cached_seed
_seed, _passphrase = None, None _cached_seed = seed
def set_passphrase(passphrase):
global _cached_passphrase
_cached_passphrase = passphrase
def clear(skip_passphrase: bool = False):
set_seed(None)
if skip_passphrase:
set_passphrase('')
else:
set_passphrase(None)

View File

@ -6,27 +6,29 @@ from apps.common.request_passphrase import protect_by_passphrase
_DEFAULT_CURVE = 'secp256k1' _DEFAULT_CURVE = 'secp256k1'
async def derive_node(ctx: wire.Context, path=(), curve_name=_DEFAULT_CURVE): async def derive_node(ctx: wire.Context, path: list, curve_name=_DEFAULT_CURVE):
seed = await _get_seed(ctx) seed = await _get_cached_seed(ctx)
node = bip32.from_seed(seed, curve_name) node = bip32.from_seed(seed, curve_name)
if path: if path:
node.derive_path(path) node.derive_path(path)
return node return node
async def _get_seed(ctx: wire.Context) -> bytes: async def _get_cached_seed(ctx: wire.Context) -> bytes:
if not storage.is_initialized():
raise wire.ProcessError('Device is not initialized')
if cache.get_seed() is None: if cache.get_seed() is None:
seed, passphrase = await _compute_seed(ctx) passphrase = await _get_cached_passphrase(ctx)
cache.set_seed(seed, passphrase) seed = bip39.seed(storage.get_mnemonic(), passphrase)
cache.set_seed(seed)
return cache.get_seed() return cache.get_seed()
async def _compute_seed(ctx: wire.Context) -> (bytes, str): async def _get_cached_passphrase(ctx: wire.Context) -> str:
if not storage.is_initialized(): if cache.get_passphrase() is None:
raise wire.ProcessError('Device is not initialized')
passphrase = await protect_by_passphrase(ctx) passphrase = await protect_by_passphrase(ctx)
return bip39.seed(storage.get_mnemonic(), passphrase), passphrase cache.set_passphrase(passphrase)
return cache.get_passphrase()
def derive_node_without_passphrase(path, curve_name=_DEFAULT_CURVE): def derive_node_without_passphrase(path, curve_name=_DEFAULT_CURVE):

View File

@ -1,5 +1,4 @@
from trezor import config from trezor import config, utils
from trezor.utils import symbol, model
from trezor.wire import register, protobuf_workflow from trezor.wire import register, protobuf_workflow
from trezor.messages import wire_types from trezor.messages import wire_types
from trezor.messages.Features import Features from trezor.messages.Features import Features
@ -9,58 +8,60 @@ from trezor.messages.Success import Success
from apps.common import storage, cache from apps.common import storage, cache
async def respond_Features(ctx, msg): def get_features():
if isinstance(msg, Initialize):
if msg.state is None or bytes(msg.state) != cache.get_state(state=bytes(msg.state)):
cache.clear()
f = Features() f = Features()
f.vendor = 'trezor.io' f.vendor = 'trezor.io'
f.major_version = symbol('VERSION_MAJOR')
f.minor_version = symbol('VERSION_MINOR')
f.patch_version = symbol('VERSION_PATCH')
f.device_id = storage.get_device_id()
f.pin_protection = config.has_pin()
f.passphrase_protection = storage.has_passphrase()
f.language = 'english' f.language = 'english'
f.major_version = utils.symbol('VERSION_MAJOR')
f.minor_version = utils.symbol('VERSION_MINOR')
f.patch_version = utils.symbol('VERSION_PATCH')
f.revision = utils.symbol('GITREV')
f.model = utils.model()
if f.model == 'EMU':
f.model = 'T' # emulator currently emulates model T
f.device_id = storage.get_device_id()
f.label = storage.get_label() f.label = storage.get_label()
f.initialized = storage.is_initialized() f.initialized = storage.is_initialized()
f.revision = symbol('GITREV') f.pin_protection = config.has_pin()
f.pin_cached = config.has_pin() f.pin_cached = config.has_pin()
f.passphrase_protection = storage.has_passphrase()
f.passphrase_cached = cache.has_passphrase() f.passphrase_cached = cache.has_passphrase()
f.needs_backup = storage.needs_backup() f.needs_backup = storage.needs_backup()
f.flags = storage.get_flags()
if model() in ['T', 'EMU']: # emulator currently emulates model T
f.model = 'T'
f.unfinished_backup = storage.unfinished_backup() f.unfinished_backup = storage.unfinished_backup()
f.flags = storage.get_flags()
return f return f
async def respond_ClearSession(ctx, msg): async def handle_Initialize(ctx, msg):
if msg.state is None or msg.state != cache.get_state(bytes(msg.state)):
cache.clear(msg.skip_passphrase)
return get_features()
async def handle_GetFeatures(ctx, msg):
return get_features()
async def handle_ClearSession(ctx, msg):
cache.clear() cache.clear()
return Success(message='Session cleared') return Success(message='Session cleared')
async def respond_Pong(ctx, msg): async def handle_Ping(ctx, msg):
if msg.button_protection: if msg.button_protection:
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from trezor.messages.ButtonRequestType import ProtectCall from trezor.messages.ButtonRequestType import ProtectCall
from trezor.ui.text import Text from trezor.ui.text import Text
from trezor import ui from trezor import ui
await require_confirm(ctx, Text('Confirm', ui.ICON_DEFAULT), ProtectCall) await require_confirm(ctx, Text('Confirm', ui.ICON_DEFAULT), ProtectCall)
if msg.passphrase_protection: if msg.passphrase_protection:
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
await protect_by_passphrase(ctx) await protect_by_passphrase(ctx)
return Success(message=msg.message) return Success(message=msg.message)
def boot(): def boot():
register(wire_types.Initialize, protobuf_workflow, respond_Features) register(wire_types.Initialize, protobuf_workflow, handle_Initialize)
register(wire_types.GetFeatures, protobuf_workflow, respond_Features) register(wire_types.GetFeatures, protobuf_workflow, handle_GetFeatures)
register(wire_types.ClearSession, protobuf_workflow, respond_ClearSession) register(wire_types.ClearSession, protobuf_workflow, handle_ClearSession)
register(wire_types.Ping, protobuf_workflow, respond_Pong) register(wire_types.Ping, protobuf_workflow, handle_Ping)