mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-23 05:40:57 +00:00
apps/homescreen: handle Initialize.skip_passphrase
TODO: tests
This commit is contained in:
parent
df17be0287
commit
cabb334448
@ -1,46 +1,57 @@
|
|||||||
from trezor.crypto import hashlib, hmac, random
|
from trezor.crypto import hashlib, hmac, random
|
||||||
from apps.common import storage
|
from apps.common import storage
|
||||||
|
|
||||||
memory = {}
|
_cached_seed = None
|
||||||
_seed = None
|
_cached_passphrase = None
|
||||||
_passphrase = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_state(state: bytes=None, passphrase: str=None):
|
def get_state(prev_state: bytes = None, passphrase: str = None) -> bytes:
|
||||||
|
if prev_state is None:
|
||||||
if state is None:
|
|
||||||
salt = random.bytes(32) # generate a random salt if no state provided
|
salt = random.bytes(32) # generate a random salt if no state provided
|
||||||
else:
|
else:
|
||||||
salt = state[:32] # use salt from provided state
|
salt = prev_state[:32] # use salt from provided state
|
||||||
|
if len(salt) != 32:
|
||||||
|
return None # invalid state
|
||||||
if passphrase is None:
|
if passphrase is None:
|
||||||
global _passphrase
|
if _cached_passphrase is None:
|
||||||
if _passphrase is None:
|
return None # we don't have any passphrase to compute the state
|
||||||
return None
|
else:
|
||||||
passphrase = _passphrase # use cached passphrase
|
passphrase = _cached_passphrase # use cached passphrase
|
||||||
|
return _compute_state(salt, passphrase)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_state(salt: bytes, passphrase: str) -> bytes:
|
||||||
# state = HMAC(passphrase, salt || device_id)
|
# state = HMAC(passphrase, salt || device_id)
|
||||||
msg = salt + storage.get_device_id().encode()
|
message = salt + storage.get_device_id().encode()
|
||||||
state = hmac.new(passphrase.encode(), msg, hashlib.sha256).digest()
|
state = hmac.new(passphrase.encode(), message, hashlib.sha256).digest()
|
||||||
|
|
||||||
return salt + state
|
return salt + state
|
||||||
|
|
||||||
|
|
||||||
def get_seed():
|
def get_seed():
|
||||||
global _seed
|
return _cached_seed
|
||||||
return _seed
|
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed, passphrase):
|
def get_passphrase():
|
||||||
global _seed, _passphrase
|
return _cached_passphrase
|
||||||
_seed, _passphrase = seed, passphrase
|
|
||||||
|
|
||||||
|
|
||||||
def has_passphrase():
|
def has_passphrase():
|
||||||
global _passphrase
|
return _cached_passphrase is not None
|
||||||
return _passphrase is not None
|
|
||||||
|
|
||||||
|
|
||||||
def clear():
|
def set_seed(seed):
|
||||||
global _seed, _passphrase
|
global _cached_seed
|
||||||
_seed, _passphrase = None, None
|
_cached_seed = seed
|
||||||
|
|
||||||
|
|
||||||
|
def set_passphrase(passphrase):
|
||||||
|
global _cached_passphrase
|
||||||
|
_cached_passphrase = passphrase
|
||||||
|
|
||||||
|
|
||||||
|
def clear(skip_passphrase: bool = False):
|
||||||
|
set_seed(None)
|
||||||
|
if skip_passphrase:
|
||||||
|
set_passphrase('')
|
||||||
|
else:
|
||||||
|
set_passphrase(None)
|
||||||
|
@ -6,27 +6,29 @@ from apps.common.request_passphrase import protect_by_passphrase
|
|||||||
_DEFAULT_CURVE = 'secp256k1'
|
_DEFAULT_CURVE = 'secp256k1'
|
||||||
|
|
||||||
|
|
||||||
async def derive_node(ctx: wire.Context, path=(), curve_name=_DEFAULT_CURVE):
|
async def derive_node(ctx: wire.Context, path: list, curve_name=_DEFAULT_CURVE):
|
||||||
seed = await _get_seed(ctx)
|
seed = await _get_cached_seed(ctx)
|
||||||
node = bip32.from_seed(seed, curve_name)
|
node = bip32.from_seed(seed, curve_name)
|
||||||
if path:
|
if path:
|
||||||
node.derive_path(path)
|
node.derive_path(path)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
async def _get_seed(ctx: wire.Context) -> bytes:
|
async def _get_cached_seed(ctx: wire.Context) -> bytes:
|
||||||
|
if not storage.is_initialized():
|
||||||
|
raise wire.ProcessError('Device is not initialized')
|
||||||
if cache.get_seed() is None:
|
if cache.get_seed() is None:
|
||||||
seed, passphrase = await _compute_seed(ctx)
|
passphrase = await _get_cached_passphrase(ctx)
|
||||||
cache.set_seed(seed, passphrase)
|
seed = bip39.seed(storage.get_mnemonic(), passphrase)
|
||||||
|
cache.set_seed(seed)
|
||||||
return cache.get_seed()
|
return cache.get_seed()
|
||||||
|
|
||||||
|
|
||||||
async def _compute_seed(ctx: wire.Context) -> (bytes, str):
|
async def _get_cached_passphrase(ctx: wire.Context) -> str:
|
||||||
if not storage.is_initialized():
|
if cache.get_passphrase() is None:
|
||||||
raise wire.ProcessError('Device is not initialized')
|
passphrase = await protect_by_passphrase(ctx)
|
||||||
|
cache.set_passphrase(passphrase)
|
||||||
passphrase = await protect_by_passphrase(ctx)
|
return cache.get_passphrase()
|
||||||
return bip39.seed(storage.get_mnemonic(), passphrase), passphrase
|
|
||||||
|
|
||||||
|
|
||||||
def derive_node_without_passphrase(path, curve_name=_DEFAULT_CURVE):
|
def derive_node_without_passphrase(path, curve_name=_DEFAULT_CURVE):
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from trezor import config
|
from trezor import config, utils
|
||||||
from trezor.utils import symbol, model
|
|
||||||
from trezor.wire import register, protobuf_workflow
|
from trezor.wire import register, protobuf_workflow
|
||||||
from trezor.messages import wire_types
|
from trezor.messages import wire_types
|
||||||
from trezor.messages.Features import Features
|
from trezor.messages.Features import Features
|
||||||
@ -9,58 +8,60 @@ from trezor.messages.Success import Success
|
|||||||
from apps.common import storage, cache
|
from apps.common import storage, cache
|
||||||
|
|
||||||
|
|
||||||
async def respond_Features(ctx, msg):
|
def get_features():
|
||||||
|
|
||||||
if isinstance(msg, Initialize):
|
|
||||||
if msg.state is None or bytes(msg.state) != cache.get_state(state=bytes(msg.state)):
|
|
||||||
cache.clear()
|
|
||||||
|
|
||||||
f = Features()
|
f = Features()
|
||||||
f.vendor = 'trezor.io'
|
f.vendor = 'trezor.io'
|
||||||
f.major_version = symbol('VERSION_MAJOR')
|
|
||||||
f.minor_version = symbol('VERSION_MINOR')
|
|
||||||
f.patch_version = symbol('VERSION_PATCH')
|
|
||||||
f.device_id = storage.get_device_id()
|
|
||||||
f.pin_protection = config.has_pin()
|
|
||||||
f.passphrase_protection = storage.has_passphrase()
|
|
||||||
f.language = 'english'
|
f.language = 'english'
|
||||||
|
f.major_version = utils.symbol('VERSION_MAJOR')
|
||||||
|
f.minor_version = utils.symbol('VERSION_MINOR')
|
||||||
|
f.patch_version = utils.symbol('VERSION_PATCH')
|
||||||
|
f.revision = utils.symbol('GITREV')
|
||||||
|
f.model = utils.model()
|
||||||
|
if f.model == 'EMU':
|
||||||
|
f.model = 'T' # emulator currently emulates model T
|
||||||
|
f.device_id = storage.get_device_id()
|
||||||
f.label = storage.get_label()
|
f.label = storage.get_label()
|
||||||
f.initialized = storage.is_initialized()
|
f.initialized = storage.is_initialized()
|
||||||
f.revision = symbol('GITREV')
|
f.pin_protection = config.has_pin()
|
||||||
f.pin_cached = config.has_pin()
|
f.pin_cached = config.has_pin()
|
||||||
|
f.passphrase_protection = storage.has_passphrase()
|
||||||
f.passphrase_cached = cache.has_passphrase()
|
f.passphrase_cached = cache.has_passphrase()
|
||||||
f.needs_backup = storage.needs_backup()
|
f.needs_backup = storage.needs_backup()
|
||||||
f.flags = storage.get_flags()
|
|
||||||
if model() in ['T', 'EMU']: # emulator currently emulates model T
|
|
||||||
f.model = 'T'
|
|
||||||
f.unfinished_backup = storage.unfinished_backup()
|
f.unfinished_backup = storage.unfinished_backup()
|
||||||
|
f.flags = storage.get_flags()
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
async def respond_ClearSession(ctx, msg):
|
async def handle_Initialize(ctx, msg):
|
||||||
|
if msg.state is None or msg.state != cache.get_state(bytes(msg.state)):
|
||||||
|
cache.clear(msg.skip_passphrase)
|
||||||
|
return get_features()
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_GetFeatures(ctx, msg):
|
||||||
|
return get_features()
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_ClearSession(ctx, msg):
|
||||||
cache.clear()
|
cache.clear()
|
||||||
return Success(message='Session cleared')
|
return Success(message='Session cleared')
|
||||||
|
|
||||||
|
|
||||||
async def respond_Pong(ctx, msg):
|
async def handle_Ping(ctx, msg):
|
||||||
|
|
||||||
if msg.button_protection:
|
if msg.button_protection:
|
||||||
from apps.common.confirm import require_confirm
|
from apps.common.confirm import require_confirm
|
||||||
from trezor.messages.ButtonRequestType import ProtectCall
|
from trezor.messages.ButtonRequestType import ProtectCall
|
||||||
from trezor.ui.text import Text
|
from trezor.ui.text import Text
|
||||||
from trezor import ui
|
from trezor import ui
|
||||||
await require_confirm(ctx, Text('Confirm', ui.ICON_DEFAULT), ProtectCall)
|
await require_confirm(ctx, Text('Confirm', ui.ICON_DEFAULT), ProtectCall)
|
||||||
|
|
||||||
if msg.passphrase_protection:
|
if msg.passphrase_protection:
|
||||||
from apps.common.request_passphrase import protect_by_passphrase
|
from apps.common.request_passphrase import protect_by_passphrase
|
||||||
await protect_by_passphrase(ctx)
|
await protect_by_passphrase(ctx)
|
||||||
|
|
||||||
return Success(message=msg.message)
|
return Success(message=msg.message)
|
||||||
|
|
||||||
|
|
||||||
def boot():
|
def boot():
|
||||||
register(wire_types.Initialize, protobuf_workflow, respond_Features)
|
register(wire_types.Initialize, protobuf_workflow, handle_Initialize)
|
||||||
register(wire_types.GetFeatures, protobuf_workflow, respond_Features)
|
register(wire_types.GetFeatures, protobuf_workflow, handle_GetFeatures)
|
||||||
register(wire_types.ClearSession, protobuf_workflow, respond_ClearSession)
|
register(wire_types.ClearSession, protobuf_workflow, handle_ClearSession)
|
||||||
register(wire_types.Ping, protobuf_workflow, respond_Pong)
|
register(wire_types.Ping, protobuf_workflow, handle_Ping)
|
||||||
|
Loading…
Reference in New Issue
Block a user