diff --git a/src/apps/common/cache.py b/src/apps/common/cache.py index 01b5b7276d..25bff34d68 100644 --- a/src/apps/common/cache.py +++ b/src/apps/common/cache.py @@ -1,46 +1,57 @@ from trezor.crypto import hashlib, hmac, random from apps.common import storage -memory = {} -_seed = None -_passphrase = None +_cached_seed = None +_cached_passphrase = None -def get_state(state: bytes=None, passphrase: str=None): - - if state is None: +def get_state(prev_state: bytes = None, passphrase: str = None) -> bytes: + if prev_state is None: salt = random.bytes(32) # generate a random salt if no state provided else: - salt = state[:32] # use salt from provided state - + salt = prev_state[:32] # use salt from provided state + if len(salt) != 32: + return None # invalid state if passphrase is None: - global _passphrase - if _passphrase is None: - return None - passphrase = _passphrase # use cached passphrase + if _cached_passphrase is None: + return None # we don't have any passphrase to compute the state + else: + passphrase = _cached_passphrase # use cached passphrase + return _compute_state(salt, passphrase) + +def _compute_state(salt: bytes, passphrase: str) -> bytes: # state = HMAC(passphrase, salt || device_id) - msg = salt + storage.get_device_id().encode() - state = hmac.new(passphrase.encode(), msg, hashlib.sha256).digest() - + message = salt + storage.get_device_id().encode() + state = hmac.new(passphrase.encode(), message, hashlib.sha256).digest() return salt + state def get_seed(): - global _seed - return _seed + return _cached_seed -def set_seed(seed, passphrase): - global _seed, _passphrase - _seed, _passphrase = seed, passphrase +def get_passphrase(): + return _cached_passphrase def has_passphrase(): - global _passphrase - return _passphrase is not None + return _cached_passphrase is not None -def clear(): - global _seed, _passphrase - _seed, _passphrase = None, None +def set_seed(seed): + global _cached_seed + _cached_seed = seed + + +def set_passphrase(passphrase): + global _cached_passphrase + _cached_passphrase = passphrase + + +def clear(skip_passphrase: bool = False): + set_seed(None) + if skip_passphrase: + set_passphrase('') + else: + set_passphrase(None) diff --git a/src/apps/common/seed.py b/src/apps/common/seed.py index 218ce84e69..bd9ae9eb78 100644 --- a/src/apps/common/seed.py +++ b/src/apps/common/seed.py @@ -6,27 +6,29 @@ from apps.common.request_passphrase import protect_by_passphrase _DEFAULT_CURVE = 'secp256k1' -async def derive_node(ctx: wire.Context, path=(), curve_name=_DEFAULT_CURVE): - seed = await _get_seed(ctx) +async def derive_node(ctx: wire.Context, path: list, curve_name=_DEFAULT_CURVE): + seed = await _get_cached_seed(ctx) node = bip32.from_seed(seed, curve_name) if path: node.derive_path(path) return node -async def _get_seed(ctx: wire.Context) -> bytes: +async def _get_cached_seed(ctx: wire.Context) -> bytes: + if not storage.is_initialized(): + raise wire.ProcessError('Device is not initialized') if cache.get_seed() is None: - seed, passphrase = await _compute_seed(ctx) - cache.set_seed(seed, passphrase) + passphrase = await _get_cached_passphrase(ctx) + seed = bip39.seed(storage.get_mnemonic(), passphrase) + cache.set_seed(seed) return cache.get_seed() -async def _compute_seed(ctx: wire.Context) -> (bytes, str): - if not storage.is_initialized(): - raise wire.ProcessError('Device is not initialized') - - passphrase = await protect_by_passphrase(ctx) - return bip39.seed(storage.get_mnemonic(), passphrase), passphrase +async def _get_cached_passphrase(ctx: wire.Context) -> str: + if cache.get_passphrase() is None: + passphrase = await protect_by_passphrase(ctx) + cache.set_passphrase(passphrase) + return cache.get_passphrase() def derive_node_without_passphrase(path, curve_name=_DEFAULT_CURVE): diff --git a/src/apps/homescreen/__init__.py b/src/apps/homescreen/__init__.py index 4c71eebf3d..cc93dc63da 100644 --- a/src/apps/homescreen/__init__.py +++ b/src/apps/homescreen/__init__.py @@ -1,5 +1,4 @@ -from trezor import config -from trezor.utils import symbol, model +from trezor import config, utils from trezor.wire import register, protobuf_workflow from trezor.messages import wire_types from trezor.messages.Features import Features @@ -9,58 +8,60 @@ from trezor.messages.Success import Success from apps.common import storage, cache -async def respond_Features(ctx, msg): - - if isinstance(msg, Initialize): - if msg.state is None or bytes(msg.state) != cache.get_state(state=bytes(msg.state)): - cache.clear() - +def get_features(): f = Features() f.vendor = 'trezor.io' - f.major_version = symbol('VERSION_MAJOR') - f.minor_version = symbol('VERSION_MINOR') - f.patch_version = symbol('VERSION_PATCH') - f.device_id = storage.get_device_id() - f.pin_protection = config.has_pin() - f.passphrase_protection = storage.has_passphrase() f.language = 'english' + f.major_version = utils.symbol('VERSION_MAJOR') + f.minor_version = utils.symbol('VERSION_MINOR') + f.patch_version = utils.symbol('VERSION_PATCH') + f.revision = utils.symbol('GITREV') + f.model = utils.model() + if f.model == 'EMU': + f.model = 'T' # emulator currently emulates model T + f.device_id = storage.get_device_id() f.label = storage.get_label() f.initialized = storage.is_initialized() - f.revision = symbol('GITREV') + f.pin_protection = config.has_pin() f.pin_cached = config.has_pin() + f.passphrase_protection = storage.has_passphrase() f.passphrase_cached = cache.has_passphrase() f.needs_backup = storage.needs_backup() - f.flags = storage.get_flags() - if model() in ['T', 'EMU']: # emulator currently emulates model T - f.model = 'T' f.unfinished_backup = storage.unfinished_backup() - + f.flags = storage.get_flags() return f -async def respond_ClearSession(ctx, msg): +async def handle_Initialize(ctx, msg): + if msg.state is None or msg.state != cache.get_state(bytes(msg.state)): + cache.clear(msg.skip_passphrase) + return get_features() + + +async def handle_GetFeatures(ctx, msg): + return get_features() + + +async def handle_ClearSession(ctx, msg): cache.clear() return Success(message='Session cleared') -async def respond_Pong(ctx, msg): - +async def handle_Ping(ctx, msg): if msg.button_protection: from apps.common.confirm import require_confirm from trezor.messages.ButtonRequestType import ProtectCall from trezor.ui.text import Text from trezor import ui await require_confirm(ctx, Text('Confirm', ui.ICON_DEFAULT), ProtectCall) - if msg.passphrase_protection: from apps.common.request_passphrase import protect_by_passphrase await protect_by_passphrase(ctx) - return Success(message=msg.message) def boot(): - register(wire_types.Initialize, protobuf_workflow, respond_Features) - register(wire_types.GetFeatures, protobuf_workflow, respond_Features) - register(wire_types.ClearSession, protobuf_workflow, respond_ClearSession) - register(wire_types.Ping, protobuf_workflow, respond_Pong) + register(wire_types.Initialize, protobuf_workflow, handle_Initialize) + register(wire_types.GetFeatures, protobuf_workflow, handle_GetFeatures) + register(wire_types.ClearSession, protobuf_workflow, handle_ClearSession) + register(wire_types.Ping, protobuf_workflow, handle_Ping)