1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-08-05 05:15:27 +00:00

apps.common: implement finish device state handling

This commit is contained in:
Pavol Rusnak 2018-02-24 18:58:02 +01:00
parent 35e1135c95
commit 502ecd7bcc
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
4 changed files with 41 additions and 19 deletions

View File

@ -1,11 +1,32 @@
from trezor.crypto import random, hashlib, hmac
from apps.common.storage import get_device_id
memory = {} memory = {}
_seed = None _seed = None
_state = None _passphrase = None
_state_salt = None
def get_state(): def get_state(salt: bytes=None, passphrase: str=None):
global _state global _passphrase, _state_salt
return _state if salt is None:
# generate a random salt if not provided and not already cached
if _state_salt is None:
_state_salt = random.bytes(32)
else:
# otherwise copy provided salt to cached salt
_state_salt = salt
# state = HMAC(passphrase, salt || device_id)
if passphrase is None:
key = _passphrase if _passphrase is not None else ''
else:
key = passphrase
msg = _state_salt + get_device_id().encode()
state = hmac.new(key.encode(), msg, hashlib.sha256).digest()
return _state_salt + state
def get_seed(): def get_seed():
@ -13,15 +34,13 @@ def get_seed():
return _seed return _seed
def set_seed(seed): def set_seed(seed, passphrase):
from trezor.crypto import bip32 global _seed, _passphrase
from trezor.crypto.hashlib import blake2s _seed, _passphrase = seed, _passphrase
node = bip32.from_seed(seed, 'secp256k1')
state = blake2s(node.public_key()).digest()
global _seed, _state
_seed, _state = seed, state
def clear(): def clear():
global _seed, _state global _seed, _passphrase
_seed, _state = None, None global _state_salt
_seed, _passphrase = None, None
_state_salt = None

View File

@ -1,4 +1,5 @@
from trezor import res, ui, wire from trezor import res, ui, wire
from apps.common.cache import get_state
async def request_passphrase(ctx): async def request_passphrase(ctx):
@ -47,7 +48,9 @@ async def request_passphrase(ctx):
raise wire.FailureError(ProcessError, 'Passphrase not provided') raise wire.FailureError(ProcessError, 'Passphrase not provided')
passphrase = ack.passphrase passphrase = ack.passphrase
# TODO: process ack.state and check against the current device state, throw error if different if ack.state is not None:
if ack.state != get_state(salt=ack.state[:32], passphrase=passphrase):
raise wire.FailureError(ProcessError, 'Passphrase mismatch')
return passphrase return passphrase

View File

@ -16,12 +16,12 @@ async def derive_node(ctx: wire.Context, path=[], curve_name=_DEFAULT_CURVE):
async def _get_seed(ctx: wire.Context) -> bytes: async def _get_seed(ctx: wire.Context) -> bytes:
from . import cache from . import cache
if cache.get_seed() is None: if cache.get_seed() is None:
seed = await _compute_seed(ctx) seed, passphrase = await _compute_seed(ctx)
cache.set_seed(seed) cache.set_seed(seed, passphrase)
return cache.get_seed() return cache.get_seed()
async def _compute_seed(ctx: wire.Context) -> bytes: async def _compute_seed(ctx: wire.Context) -> (bytes, str):
from trezor.messages.FailureType import ProcessError from trezor.messages.FailureType import ProcessError
from .request_passphrase import protect_by_passphrase from .request_passphrase import protect_by_passphrase
from . import storage from . import storage
@ -30,7 +30,7 @@ async def _compute_seed(ctx: wire.Context) -> bytes:
raise wire.FailureError(ProcessError, 'Device is not initialized') raise wire.FailureError(ProcessError, 'Device is not initialized')
passphrase = await protect_by_passphrase(ctx) passphrase = await protect_by_passphrase(ctx)
return bip39.seed(storage.get_mnemonic(), passphrase) return bip39.seed(storage.get_mnemonic(), passphrase), passphrase
def derive_node_without_passphrase(path, curve_name=_DEFAULT_CURVE): def derive_node_without_passphrase(path, curve_name=_DEFAULT_CURVE):

View File

@ -10,7 +10,7 @@ async def respond_Features(ctx, msg):
from trezor.messages.Features import Features from trezor.messages.Features import Features
if msg.__qualname__ == 'Initialize': if msg.__qualname__ == 'Initialize':
if msg.state is None or msg.state != cache.get_state(): if msg.state is None or msg.state != cache.get_state(salt=msg.state[:32]):
cache.clear() cache.clear()
f = Features() f = Features()