core: create top-level storage module

This is to avoid including app-specific functionality in storage and
avoid circular imports. The following policy is now in effect: modules
from `storage` namespace must not import from `apps` namespace.

In most files, the change only involves changing import paths.

A minor refactor was needed in case of webauthn: basic get/set/delete
functionality was left in storage.webauthn, and more advanced logic on
top of it was moved to apps.webauthn.resident_credentials.

A significant refactor was needed for sd_salt, where application (and
UI) logic was tightly coupled with the IO code. This is now separated,
and storage.sd_salt deals exclusively with the IO side, while the app/UI
logic is implemented on top of it in apps.common.sd_salt and
apps.management.sd_protect.
pull/665/head
matejcik 5 years ago
parent 39a532c8b1
commit 5c93ecd53a

@ -1,8 +1,10 @@
import storage
import storage.cache
from trezor import wire from trezor import wire
from trezor.crypto import bip32 from trezor.crypto import bip32
from apps.cardano import CURVE, SEED_NAMESPACE from apps.cardano import CURVE, SEED_NAMESPACE
from apps.common import cache, mnemonic, storage from apps.common import mnemonic
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
@ -29,10 +31,10 @@ class Keychain:
async def _get_passphrase(ctx: wire.Context) -> bytes: async def _get_passphrase(ctx: wire.Context) -> bytes:
passphrase = cache.get_passphrase() passphrase = storage.cache.get_passphrase()
if passphrase is None: if passphrase is None:
passphrase = await protect_by_passphrase(ctx) passphrase = await protect_by_passphrase(ctx)
cache.set_passphrase(passphrase) storage.cache.set_passphrase(passphrase)
return passphrase return passphrase
@ -46,11 +48,11 @@ async def get_keychain(ctx: wire.Context) -> Keychain:
passphrase = await _get_passphrase(ctx) passphrase = await _get_passphrase(ctx)
root = bip32.from_mnemonic_cardano(mnemonic.get_secret().decode(), passphrase) root = bip32.from_mnemonic_cardano(mnemonic.get_secret().decode(), passphrase)
else: else:
seed = cache.get_seed() seed = storage.cache.get_seed()
if seed is None: if seed is None:
passphrase = await _get_passphrase(ctx) passphrase = await _get_passphrase(ctx)
seed = mnemonic.get_seed(passphrase) seed = mnemonic.get_seed(passphrase)
cache.set_seed(seed) storage.cache.set_seed(seed)
root = bip32.from_seed(seed, "ed25519 cardano seed") root = bip32.from_seed(seed, "ed25519 cardano seed")
# derive the namespaced root node # derive the namespaced root node

@ -1,9 +1,8 @@
import storage.device
from trezor import ui, workflow from trezor import ui, workflow
from trezor.crypto import bip39, slip39 from trezor.crypto import bip39, slip39
from trezor.messages import BackupType from trezor.messages import BackupType
from apps.common.storage import device as storage_device
if False: if False:
from typing import Optional, Tuple from typing import Optional, Tuple
from trezor.messages.ResetDevice import EnumTypeBackupType from trezor.messages.ResetDevice import EnumTypeBackupType
@ -14,11 +13,11 @@ def get() -> Tuple[Optional[bytes], int]:
def get_secret() -> Optional[bytes]: def get_secret() -> Optional[bytes]:
return storage_device.get_mnemonic_secret() return storage.device.get_mnemonic_secret()
def get_type() -> EnumTypeBackupType: def get_type() -> EnumTypeBackupType:
return storage_device.get_backup_type() return storage.device.get_backup_type()
def is_bip39() -> bool: def is_bip39() -> bool:
@ -43,8 +42,8 @@ def get_seed(passphrase: str = "", progress_bar: bool = True) -> bytes:
seed = bip39.seed(mnemonic_secret.decode(), passphrase, render_func) seed = bip39.seed(mnemonic_secret.decode(), passphrase, render_func)
else: # SLIP-39 else: # SLIP-39
identifier = storage_device.get_slip39_identifier() identifier = storage.device.get_slip39_identifier()
iteration_exponent = storage_device.get_slip39_iteration_exponent() iteration_exponent = storage.device.get_slip39_iteration_exponent()
if identifier is None or iteration_exponent is None: if identifier is None or iteration_exponent is None:
# Identifier or exponent expected but not found # Identifier or exponent expected but not found
raise RuntimeError raise RuntimeError

@ -1,5 +1,7 @@
from micropython import const from micropython import const
import storage.device
from storage import cache
from trezor import ui, wire from trezor import ui, wire
from trezor.messages import ButtonRequestType, PassphraseSourceType from trezor.messages import ButtonRequestType, PassphraseSourceType
from trezor.messages.ButtonAck import ButtonAck from trezor.messages.ButtonAck import ButtonAck
@ -12,9 +14,6 @@ from trezor.ui.passphrase import CANCELLED, PassphraseKeyboard, PassphraseSource
from trezor.ui.popup import Popup from trezor.ui.popup import Popup
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common import cache
from apps.common.storage import device as storage_device
if __debug__: if __debug__:
from apps.debug import input_signal from apps.debug import input_signal
@ -22,14 +21,14 @@ _MAX_PASSPHRASE_LEN = const(50)
async def protect_by_passphrase(ctx: wire.Context) -> str: async def protect_by_passphrase(ctx: wire.Context) -> str:
if storage_device.has_passphrase(): if storage.device.has_passphrase():
return await request_passphrase(ctx) return await request_passphrase(ctx)
else: else:
return "" return ""
async def request_passphrase(ctx: wire.Context) -> str: async def request_passphrase(ctx: wire.Context) -> str:
source = storage_device.get_passphrase_source() source = storage.device.get_passphrase_source()
if source == PassphraseSourceType.ASK: if source == PassphraseSourceType.ASK:
source = await request_passphrase_source(ctx) source = await request_passphrase_source(ctx)
passphrase = await request_passphrase_ack( passphrase = await request_passphrase_ack(

@ -8,7 +8,6 @@ from trezor.ui.popup import Popup
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.sd_salt import request_sd_salt from apps.common.sd_salt import request_sd_salt
from apps.common.storage import device
if False: if False:
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
@ -81,11 +80,7 @@ async def pin_mismatch() -> None:
async def request_pin_and_sd_salt( async def request_pin_and_sd_salt(
ctx: wire.Context, prompt: str = "Enter your PIN", allow_cancel: bool = True ctx: wire.Context, prompt: str = "Enter your PIN", allow_cancel: bool = True
) -> Tuple[str, Optional[bytearray]]: ) -> Tuple[str, Optional[bytearray]]:
salt_auth_key = device.get_sd_salt_auth_key() salt = await request_sd_salt(ctx)
if salt_auth_key is not None:
salt = await request_sd_salt(ctx, salt_auth_key) # type: Optional[bytearray]
else:
salt = None
if config.has_pin(): if config.has_pin():
pin = await request_pin_ack(ctx, prompt, config.get_pin_rem(), allow_cancel) pin = await request_pin_ack(ctx, prompt, config.get_pin_rem(), allow_cancel)
@ -98,11 +93,7 @@ async def request_pin_and_sd_salt(
async def verify_user_pin( async def verify_user_pin(
prompt: str = "Enter your PIN", allow_cancel: bool = True, retry: bool = True prompt: str = "Enter your PIN", allow_cancel: bool = True, retry: bool = True
) -> None: ) -> None:
salt_auth_key = device.get_sd_salt_auth_key() salt = await request_sd_salt()
if salt_auth_key is not None:
salt = await request_sd_salt(None, salt_auth_key) # type: Optional[bytearray]
else:
salt = None
if not config.has_pin() and not config.check_pin(pin_to_int(""), salt): if not config.has_pin() and not config.check_pin(pin_to_int(""), salt):
raise RuntimeError raise RuntimeError

@ -1,11 +1,7 @@
from micropython import const import storage.sd_salt
from storage.sd_salt import SD_CARD_HOT_SWAPPABLE, SdSaltMismatch
from trezor import io, ui, wire from trezor import io, ui, wire
from trezor.crypto import hmac
from trezor.crypto.hashlib import sha256
from trezor.ui.confirm import CONFIRMED, Confirm
from trezor.ui.text import Text from trezor.ui.text import Text
from trezor.utils import consteq
from apps.common.confirm import confirm from apps.common.confirm import confirm
@ -17,13 +13,7 @@ class SdProtectCancelled(Exception):
pass pass
SD_CARD_HOT_SWAPPABLE = False async def _wrong_card_dialog(ctx: wire.GenericContext) -> bool:
SD_SALT_LEN_BYTES = const(32)
SD_SALT_AUTH_TAG_LEN_BYTES = const(16)
SD_SALT_AUTH_KEY_LEN_BYTES = const(16)
async def _wrong_card_dialog(ctx: Optional[wire.Context]) -> None:
text = Text("SD card protection", ui.ICON_WRONG) text = Text("SD card protection", ui.ICON_WRONG)
text.bold("Wrong SD card.") text.bold("Wrong SD card.")
text.br_half() text.br_half()
@ -36,15 +26,10 @@ async def _wrong_card_dialog(ctx: Optional[wire.Context]) -> None:
btn_confirm = None btn_confirm = None
btn_cancel = "Close" btn_cancel = "Close"
if ctx is None: return await confirm(ctx, text, confirm=btn_confirm, cancel=btn_cancel)
if await Confirm(text, confirm=btn_confirm, cancel=btn_cancel) is not CONFIRMED:
raise SdProtectCancelled
else:
if not await confirm(ctx, text, confirm=btn_confirm, cancel=btn_cancel):
raise wire.ProcessError("Wrong SD card.")
async def _insert_card_dialog(ctx: Optional[wire.Context]) -> None: async def _insert_card_dialog(ctx: wire.GenericContext) -> None:
text = Text("SD card protection", ui.ICON_WRONG) text = Text("SD card protection", ui.ICON_WRONG)
text.bold("SD card required.") text.bold("SD card required.")
text.br_half() text.br_half()
@ -57,171 +42,34 @@ async def _insert_card_dialog(ctx: Optional[wire.Context]) -> None:
btn_confirm = None btn_confirm = None
btn_cancel = "Close" btn_cancel = "Close"
if ctx is None: if not await confirm(ctx, text, confirm=btn_confirm, cancel=btn_cancel):
if await Confirm(text, confirm=btn_confirm, cancel=btn_cancel) is not CONFIRMED: raise SdProtectCancelled
raise SdProtectCancelled
else:
if not await confirm(ctx, text, confirm=btn_confirm, cancel=btn_cancel):
raise wire.ProcessError("SD card required.")
async def _write_failed_dialog(ctx: Optional[wire.Context]) -> None: async def sd_write_failed_dialog(ctx: wire.GenericContext) -> bool:
text = Text("SD card protection", ui.ICON_WRONG, ui.RED) text = Text("SD card protection", ui.ICON_WRONG, ui.RED)
text.normal("Failed to write data to", "the SD card.") text.normal("Failed to write data to", "the SD card.")
if ctx is None: return await confirm(ctx, text, confirm="Retry", cancel="Abort")
if await Confirm(text, confirm="Retry", cancel="Abort") is not CONFIRMED:
raise OSError
else:
if not await confirm(ctx, text, confirm="Retry", cancel="Abort"):
raise wire.ProcessError("Failed to write to SD card.")
def _get_device_dir() -> str:
from apps.common.storage.device import get_device_id
return "/trezor/device_%s" % get_device_id().lower()
def _get_salt_path(new: bool = False) -> str: async def ensure_sd_card(ctx: wire.GenericContext) -> None:
if new: sd = io.SDCard()
return "%s/salt.new" % _get_device_dir() while not sd.power(True):
else: await _insert_card_dialog(ctx)
return "%s/salt" % _get_device_dir()
def _load_salt(fs: io.FatFS, auth_key: bytes, path: str) -> Optional[bytearray]:
# Load the salt file if it exists.
try:
with fs.open(path, "r") as f:
salt = bytearray(SD_SALT_LEN_BYTES)
stored_tag = bytearray(SD_SALT_AUTH_TAG_LEN_BYTES)
f.read(salt)
f.read(stored_tag)
except OSError:
return None
# Check the salt's authentication tag.
computed_tag = hmac.new(auth_key, salt, sha256).digest()[
:SD_SALT_AUTH_TAG_LEN_BYTES
]
if not consteq(computed_tag, stored_tag):
return None
return salt
async def request_sd_salt( async def request_sd_salt(
ctx: Optional[wire.Context], salt_auth_key: bytes ctx: wire.GenericContext = wire.DUMMY_CONTEXT
) -> bytearray: ) -> Optional[bytearray]:
salt_path = _get_salt_path()
new_salt_path = _get_salt_path(True)
while True:
sd = io.SDCard()
fs = io.FatFS()
while not sd.power(True):
await _insert_card_dialog(ctx)
try:
fs.mount()
salt = _load_salt(fs, salt_auth_key, salt_path)
if salt is not None:
return salt
# Check if there is a new salt.
salt = _load_salt(fs, salt_auth_key, new_salt_path)
if salt is not None:
# SD salt regeneration was interrupted earlier. Bring into consistent state.
# TODO Possibly overwrite salt file with random data.
try:
fs.unlink(salt_path)
except OSError:
pass
try:
fs.rename(new_salt_path, salt_path)
except OSError:
error_dialog = _write_failed_dialog(ctx)
else:
return salt
else:
# No valid salt file on this SD card.
error_dialog = _wrong_card_dialog(ctx)
finally:
fs.unmount()
sd.power(False)
await error_dialog
async def set_sd_salt(
ctx: Optional[wire.Context], salt: bytes, salt_tag: bytes, new: bool = False
) -> None:
salt_path = _get_salt_path(new)
while True: while True:
sd = io.SDCard() ensure_sd_card(ctx)
while not sd.power(True):
await _insert_card_dialog(ctx)
try: try:
fs = io.FatFS() return storage.sd_salt.load_sd_salt()
fs.mount() except SdSaltMismatch as e:
fs.mkdir("/trezor", True) if not await _wrong_card_dialog(ctx):
fs.mkdir(_get_device_dir(), True) raise SdProtectCancelled from e
with fs.open(salt_path, "w") as f:
f.write(salt)
f.write(salt_tag)
break
except Exception:
fs.unmount()
sd.power(False)
await _write_failed_dialog(ctx)
fs.unmount()
sd.power(False)
async def stage_sd_salt(
ctx: Optional[wire.Context], salt: bytes, salt_tag: bytes
) -> None:
await set_sd_salt(ctx, salt, salt_tag, True)
async def commit_sd_salt(ctx: Optional[wire.Context]) -> None:
salt_path = _get_salt_path()
new_salt_path = _get_salt_path(True)
sd = io.SDCard()
fs = io.FatFS()
if not sd.power(True):
raise OSError
try:
fs.mount()
# TODO Possibly overwrite salt file with random data.
try:
fs.unlink(salt_path)
except OSError: except OSError:
pass # This happens when there is both old and new salt file, and we can't move
fs.rename(new_salt_path, salt_path) # new salt over the old salt. If the user clicks Retry, we will try again.
finally: if not await sd_write_failed_dialog(ctx):
fs.unmount() raise
sd.power(False)
async def remove_sd_salt(ctx: Optional[wire.Context]) -> None:
salt_path = _get_salt_path()
sd = io.SDCard()
fs = io.FatFS()
if not sd.power(True):
raise OSError
try:
fs.mount()
# TODO Possibly overwrite salt file with random data.
fs.unlink(salt_path)
finally:
fs.unmount()
sd.power(False)

@ -1,8 +1,10 @@
import storage
import storage.cache
from trezor import wire from trezor import wire
from trezor.crypto import bip32, hashlib, hmac from trezor.crypto import bip32, hashlib, hmac
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from apps.common import HARDENED, cache, mnemonic, storage from apps.common import HARDENED, mnemonic
from apps.common.request_passphrase import protect_by_passphrase from apps.common.request_passphrase import protect_by_passphrase
if False: if False:
@ -110,14 +112,14 @@ class Keychain:
async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain: async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain:
if not storage.is_initialized(): if not storage.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
seed = cache.get_seed() seed = storage.cache.get_seed()
if seed is None: if seed is None:
passphrase = cache.get_passphrase() passphrase = storage.cache.get_passphrase()
if passphrase is None: if passphrase is None:
passphrase = await protect_by_passphrase(ctx) passphrase = await protect_by_passphrase(ctx)
cache.set_passphrase(passphrase) storage.cache.set_passphrase(passphrase)
seed = mnemonic.get_seed(passphrase) seed = mnemonic.get_seed(passphrase)
cache.set_seed(seed) storage.cache.set_seed(seed)
keychain = Keychain(seed, namespaces) keychain = Keychain(seed, namespaces)
return keychain return keychain
@ -127,10 +129,10 @@ def derive_node_without_passphrase(
) -> bip32.HDNode: ) -> bip32.HDNode:
if not storage.is_initialized(): if not storage.is_initialized():
raise Exception("Device is not initialized") raise Exception("Device is not initialized")
seed = cache.get_seed_without_passphrase() seed = storage.cache.get_seed_without_passphrase()
if seed is None: if seed is None:
seed = mnemonic.get_seed(progress_bar=False) seed = mnemonic.get_seed(progress_bar=False)
cache.set_seed_without_passphrase(seed) storage.cache.set_seed_without_passphrase(seed)
node = bip32.from_seed(seed, curve_name) node = bip32.from_seed(seed, curve_name)
node.derive_path(path) node.derive_path(path)
return node return node
@ -139,10 +141,10 @@ def derive_node_without_passphrase(
def derive_slip21_node_without_passphrase(path: list) -> Slip21Node: def derive_slip21_node_without_passphrase(path: list) -> Slip21Node:
if not storage.is_initialized(): if not storage.is_initialized():
raise Exception("Device is not initialized") raise Exception("Device is not initialized")
seed = cache.get_seed_without_passphrase() seed = storage.cache.get_seed_without_passphrase()
if seed is None: if seed is None:
seed = mnemonic.get_seed(progress_bar=False) seed = mnemonic.get_seed(progress_bar=False)
cache.set_seed_without_passphrase(seed) storage.cache.set_seed_without_passphrase(seed)
node = Slip21Node(seed) node = Slip21Node(seed)
node.derive_path(path) node.derive_path(path)
return node return node

@ -1,97 +0,0 @@
from micropython import const
from apps.common.storage import common
from apps.webauthn.credential import Credential, Fido2Credential
if False:
from typing import List, Optional
_RESIDENT_CREDENTIAL_START_KEY = const(1)
_MAX_RESIDENT_CREDENTIALS = const(100)
def get_resident_credentials(rp_id_hash: Optional[bytes] = None) -> List[Credential]:
creds = [] # type: List[Credential]
for i in range(_MAX_RESIDENT_CREDENTIALS):
cred = get_resident_credential(i, rp_id_hash)
if cred is not None:
creds.append(cred)
return creds
def get_resident_credential(
index: int, rp_id_hash: Optional[bytes] = None
) -> Optional[Credential]:
if not (0 <= index < _MAX_RESIDENT_CREDENTIALS):
return None
stored_cred_data = common.get(
common.APP_WEBAUTHN, index + _RESIDENT_CREDENTIAL_START_KEY
)
if stored_cred_data is None:
return None
stored_rp_id_hash = stored_cred_data[:32]
stored_cred_id = stored_cred_data[32:]
if rp_id_hash is not None and rp_id_hash != stored_rp_id_hash:
# Stored credential is not for this RP ID.
return None
stored_cred = Fido2Credential.from_cred_id(stored_cred_id, stored_rp_id_hash)
if stored_cred is None:
return None
stored_cred.index = index
return stored_cred
def store_resident_credential(cred: Fido2Credential) -> bool:
slot = None
for i in range(_MAX_RESIDENT_CREDENTIALS):
stored_cred_data = common.get(
common.APP_WEBAUTHN, i + _RESIDENT_CREDENTIAL_START_KEY
)
if stored_cred_data is None:
if slot is None:
slot = i
continue
stored_rp_id_hash = stored_cred_data[:32]
stored_cred_id = stored_cred_data[32:]
if cred.rp_id_hash != stored_rp_id_hash:
# Stored credential is not for this RP ID.
continue
stored_cred = Fido2Credential.from_cred_id(stored_cred_id, stored_rp_id_hash)
if stored_cred is None:
# Stored credential is not for this RP ID.
continue
# If a credential for the same RP ID and user ID already exists, then overwrite it.
if stored_cred.user_id == cred.user_id:
slot = i
break
if slot is None:
return False
common.set(
common.APP_WEBAUTHN,
slot + _RESIDENT_CREDENTIAL_START_KEY,
cred.rp_id_hash + cred.id,
)
return True
def erase_resident_credentials() -> None:
for i in range(_MAX_RESIDENT_CREDENTIALS):
common.delete(common.APP_WEBAUTHN, i + _RESIDENT_CREDENTIAL_START_KEY)
def erase_resident_credential(index: int) -> bool:
if not (0 <= index < _MAX_RESIDENT_CREDENTIALS):
return False
common.delete(common.APP_WEBAUTHN, index + _RESIDENT_CREDENTIAL_START_KEY)
return True

@ -86,8 +86,8 @@ if __debug__:
ctx: wire.Context, msg: DebugLinkGetState ctx: wire.Context, msg: DebugLinkGetState
) -> DebugLinkState: ) -> DebugLinkState:
from trezor.messages.DebugLinkState import DebugLinkState from trezor.messages.DebugLinkState import DebugLinkState
from storage.device import has_passphrase
from apps.common import mnemonic from apps.common import mnemonic
from apps.common.storage.device import has_passphrase
m = DebugLinkState() m = DebugLinkState()
m.mnemonic_secret = mnemonic.get_secret() m.mnemonic_secret = mnemonic.get_secret()

@ -1,11 +1,15 @@
import storage
import storage.device
import storage.recovery
import storage.sd_salt
from storage import cache
from trezor import config, io, utils, wire from trezor import config, io, utils, wire
from trezor.messages import Capability, MessageType from trezor.messages import Capability, MessageType
from trezor.messages.Features import Features from trezor.messages.Features import Features
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.wire import register from trezor.wire import register
from apps.common import cache, mnemonic, storage from apps.common import mnemonic
from apps.common.storage import device as storage_device, recovery as storage_recovery
if False: if False:
from typing import NoReturn from typing import NoReturn
@ -25,18 +29,18 @@ def get_features() -> Features:
f.patch_version = utils.VERSION_PATCH f.patch_version = utils.VERSION_PATCH
f.revision = utils.GITREV.encode() f.revision = utils.GITREV.encode()
f.model = utils.MODEL f.model = utils.MODEL
f.device_id = storage_device.get_device_id() f.device_id = storage.device.get_device_id()
f.label = storage_device.get_label() f.label = storage.device.get_label()
f.initialized = storage.is_initialized() f.initialized = storage.is_initialized()
f.pin_protection = config.has_pin() f.pin_protection = config.has_pin()
f.pin_cached = config.has_pin() f.pin_cached = config.has_pin()
f.passphrase_protection = storage_device.has_passphrase() f.passphrase_protection = storage.device.has_passphrase()
f.passphrase_cached = cache.has_passphrase() f.passphrase_cached = cache.has_passphrase()
f.needs_backup = storage_device.needs_backup() f.needs_backup = storage.device.needs_backup()
f.unfinished_backup = storage_device.unfinished_backup() f.unfinished_backup = storage.device.unfinished_backup()
f.no_backup = storage_device.no_backup() f.no_backup = storage.device.no_backup()
f.flags = storage_device.get_flags() f.flags = storage.device.get_flags()
f.recovery_mode = storage_recovery.is_in_progress() f.recovery_mode = storage.recovery.is_in_progress()
f.backup_type = mnemonic.get_type() f.backup_type = mnemonic.get_type()
if utils.BITCOIN_ONLY: if utils.BITCOIN_ONLY:
f.capabilities = [ f.capabilities = [
@ -65,7 +69,7 @@ def get_features() -> Features:
Capability.ShamirGroups, Capability.ShamirGroups,
] ]
f.sd_card_present = io.SDCard().present() f.sd_card_present = io.SDCard().present()
f.sd_protection = storage.device.get_sd_salt_auth_key() is not None f.sd_protection = storage.sd_salt.is_enabled()
return f return f

@ -1,8 +1,7 @@
import storage
import storage.device
from trezor import config, res, ui from trezor import config, res, ui
from apps.common import storage
from apps.common.storage import device as storage_device
async def homescreen() -> None: async def homescreen() -> None:
await Homescreen() await Homescreen()
@ -20,17 +19,17 @@ class Homescreen(ui.Layout):
if not storage.is_initialized(): if not storage.is_initialized():
label = "Go to trezor.io/start" label = "Go to trezor.io/start"
else: else:
label = storage_device.get_label() or "My Trezor" label = storage.device.get_label() or "My Trezor"
image = storage_device.get_homescreen() image = storage.device.get_homescreen()
if not image: if not image:
image = res.load("apps/homescreen/res/bg.toif") image = res.load("apps/homescreen/res/bg.toif")
if storage.is_initialized() and storage_device.no_backup(): if storage.is_initialized() and storage.device.no_backup():
ui.header_error("SEEDLESS") ui.header_error("SEEDLESS")
elif storage.is_initialized() and storage_device.unfinished_backup(): elif storage.is_initialized() and storage.device.unfinished_backup():
ui.header_error("BACKUP FAILED!") ui.header_error("BACKUP FAILED!")
elif storage.is_initialized() and storage_device.needs_backup(): elif storage.is_initialized() and storage.device.needs_backup():
ui.header_warning("NEEDS BACKUP!") ui.header_warning("NEEDS BACKUP!")
elif storage.is_initialized() and not config.has_pin(): elif storage.is_initialized() and not config.has_pin():
ui.header_warning("PIN NOT SET!") ui.header_warning("PIN NOT SET!")

@ -1,7 +1,6 @@
from storage.device import set_flags
from trezor.messages.Success import Success from trezor.messages.Success import Success
from apps.common.storage.device import set_flags
async def apply_flags(ctx, msg): async def apply_flags(ctx, msg):
set_flags(msg.flags) set_flags(msg.flags)

@ -1,10 +1,10 @@
import storage.device
from trezor import ui, wire from trezor import ui, wire
from trezor.messages import ButtonRequestType, PassphraseSourceType from trezor.messages import ButtonRequestType, PassphraseSourceType
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.storage import device as storage_device
async def apply_settings(ctx, msg): async def apply_settings(ctx, msg):
@ -18,7 +18,7 @@ async def apply_settings(ctx, msg):
raise wire.ProcessError("No setting provided") raise wire.ProcessError("No setting provided")
if msg.homescreen is not None: if msg.homescreen is not None:
if len(msg.homescreen) > storage_device.HOMESCREEN_MAXSIZE: if len(msg.homescreen) > storage.device.HOMESCREEN_MAXSIZE:
raise wire.DataError("Homescreen is too complex") raise wire.DataError("Homescreen is too complex")
await require_confirm_change_homescreen(ctx) await require_confirm_change_homescreen(ctx)
@ -34,7 +34,7 @@ async def apply_settings(ctx, msg):
if msg.display_rotation is not None: if msg.display_rotation is not None:
await require_confirm_change_display_rotation(ctx, msg.display_rotation) await require_confirm_change_display_rotation(ctx, msg.display_rotation)
storage_device.load_settings( storage.device.load_settings(
label=msg.label, label=msg.label,
use_passphrase=msg.use_passphrase, use_passphrase=msg.use_passphrase,
homescreen=msg.homescreen, homescreen=msg.homescreen,
@ -43,7 +43,7 @@ async def apply_settings(ctx, msg):
) )
if msg.display_rotation is not None: if msg.display_rotation is not None:
ui.display.orientation(storage_device.get_rotation()) ui.display.orientation(storage.device.get_rotation())
return Success(message="Settings applied") return Success(message="Settings applied")

@ -1,25 +1,26 @@
import storage
import storage.device
from trezor import wire from trezor import wire
from trezor.messages.Success import Success from trezor.messages.Success import Success
from apps.common import mnemonic, storage from apps.common import mnemonic
from apps.common.storage import device as storage_device
from apps.management.reset_device import backup_seed, layout from apps.management.reset_device import backup_seed, layout
async def backup_device(ctx, msg): async def backup_device(ctx, msg):
if not storage.is_initialized(): if not storage.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
if not storage_device.needs_backup(): if not storage.device.needs_backup():
raise wire.ProcessError("Seed already backed up") raise wire.ProcessError("Seed already backed up")
mnemonic_secret, mnemonic_type = mnemonic.get() mnemonic_secret, mnemonic_type = mnemonic.get()
storage_device.set_unfinished_backup(True) storage.device.set_unfinished_backup(True)
storage_device.set_backed_up() storage.device.set_backed_up()
await backup_seed(ctx, mnemonic_type, mnemonic_secret) await backup_seed(ctx, mnemonic_type, mnemonic_secret)
storage_device.set_unfinished_backup(False) storage.device.set_unfinished_backup(False)
await layout.show_backup_success(ctx) await layout.show_backup_success(ctx)

@ -1,3 +1,4 @@
from storage import is_initialized
from trezor import config, ui, wire from trezor import config, ui, wire
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.pin import pin_to_int from trezor.pin import pin_to_int
@ -10,7 +11,6 @@ from apps.common.request_pin import (
request_pin_confirm, request_pin_confirm,
show_pin_invalid, show_pin_invalid,
) )
from apps.common.storage import is_initialized
if False: if False:
from trezor.messages.ChangePin import ChangePin from trezor.messages.ChangePin import ChangePin

@ -1,3 +1,5 @@
import storage
import storage.device
from trezor import config, wire from trezor import config, wire
from trezor.crypto import bip39, slip39 from trezor.crypto import bip39, slip39
from trezor.messages import BackupType from trezor.messages import BackupType
@ -5,9 +7,7 @@ from trezor.messages.Success import Success
from trezor.pin import pin_to_int from trezor.pin import pin_to_int
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common import storage
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.storage import device as storage_device
from apps.management import backup_types from apps.management import backup_types
@ -33,13 +33,13 @@ async def load_device(ctx, msg):
backup_type = BackupType.Slip39_Advanced backup_type = BackupType.Slip39_Advanced
else: else:
raise RuntimeError("Invalid group count") raise RuntimeError("Invalid group count")
storage_device.set_slip39_identifier(identifier) storage.device.set_slip39_identifier(identifier)
storage_device.set_slip39_iteration_exponent(iteration_exponent) storage.device.set_slip39_iteration_exponent(iteration_exponent)
storage_device.store_mnemonic_secret( storage.device.store_mnemonic_secret(
secret, backup_type, needs_backup=True, no_backup=False secret, backup_type, needs_backup=True, no_backup=False
) )
storage_device.load_settings( storage.device.load_settings(
use_passphrase=msg.passphrase_protection, label=msg.label use_passphrase=msg.passphrase_protection, label=msg.label
) )
if msg.pin: if msg.pin:

@ -1,17 +1,18 @@
import storage
import storage.device
import storage.recovery
from trezor import config, ui, wire from trezor import config, ui, wire
from trezor.messages import ButtonRequestType from trezor.messages import ButtonRequestType
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.pin import pin_to_int from trezor.pin import pin_to_int
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common import storage
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.request_pin import ( from apps.common.request_pin import (
request_pin_and_sd_salt, request_pin_and_sd_salt,
request_pin_confirm, request_pin_confirm,
show_pin_invalid, show_pin_invalid,
) )
from apps.common.storage import device as storage_device, recovery as storage_recovery
from apps.management.recovery_device.homescreen import recovery_process from apps.management.recovery_device.homescreen import recovery_process
if False: if False:
@ -44,13 +45,13 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None) config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None)
if msg.u2f_counter: if msg.u2f_counter:
storage_device.set_u2f_counter(msg.u2f_counter) storage.device.set_u2f_counter(msg.u2f_counter)
storage_device.load_settings( storage.device.load_settings(
label=msg.label, use_passphrase=msg.passphrase_protection label=msg.label, use_passphrase=msg.passphrase_protection
) )
storage_recovery.set_in_progress(True) storage.recovery.set_in_progress(True)
if msg.dry_run: if msg.dry_run:
storage_recovery.set_dry_run(msg.dry_run) storage.recovery.set_dry_run(msg.dry_run)
result = await recovery_process(ctx) result = await recovery_process(ctx)
@ -63,7 +64,7 @@ def _check_state(msg: RecoveryDevice) -> None:
if msg.dry_run and not storage.is_initialized(): if msg.dry_run and not storage.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
if storage_recovery.is_in_progress(): if storage.recovery.is_in_progress():
raise RuntimeError( raise RuntimeError(
"Function recovery_device should not be invoked when recovery is already in progress" "Function recovery_device should not be invoked when recovery is already in progress"
) )

@ -1,3 +1,7 @@
import storage
import storage.device
import storage.recovery
import storage.recovery_shares
from trezor import loop, utils, wire from trezor import loop, utils, wire
from trezor.crypto import slip39 from trezor.crypto import slip39
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
@ -7,13 +11,8 @@ from trezor.messages.Success import Success
from . import recover from . import recover
from apps.common import mnemonic, storage from apps.common import mnemonic
from apps.common.layout import show_success from apps.common.layout import show_success
from apps.common.storage import (
device as storage_device,
recovery as storage_recovery,
recovery_shares as storage_recovery_shares,
)
from apps.management import backup_types from apps.management import backup_types
from apps.management.recovery_device import layout from apps.management.recovery_device import layout
@ -38,9 +37,9 @@ async def recovery_process(ctx: wire.GenericContext) -> Success:
try: try:
result = await _continue_recovery_process(ctx) result = await _continue_recovery_process(ctx)
except recover.RecoveryAborted: except recover.RecoveryAborted:
dry_run = storage_recovery.is_dry_run() dry_run = storage.recovery.is_dry_run()
if dry_run: if dry_run:
storage_recovery.end_progress() storage.recovery.end_progress()
else: else:
storage.wipe() storage.wipe()
raise wire.ActionCancelled("Cancelled") raise wire.ActionCancelled("Cancelled")
@ -49,7 +48,7 @@ async def recovery_process(ctx: wire.GenericContext) -> Success:
async def _continue_recovery_process(ctx: wire.GenericContext) -> Success: async def _continue_recovery_process(ctx: wire.GenericContext) -> Success:
# gather the current recovery state from storage # gather the current recovery state from storage
dry_run = storage_recovery.is_dry_run() dry_run = storage.recovery.is_dry_run()
word_count, backup_type = recover.load_slip39_state() word_count, backup_type = recover.load_slip39_state()
# Both word_count and backup_type are derived from the same data. Both will be # Both word_count and backup_type are derived from the same data. Both will be
@ -112,17 +111,17 @@ async def _finish_recovery_dry_run(
# Check that the identifier and iteration exponent match as well # Check that the identifier and iteration exponent match as well
if is_slip39: if is_slip39:
result &= ( result &= (
storage_device.get_slip39_identifier() storage.device.get_slip39_identifier()
== storage_recovery.get_slip39_identifier() == storage.recovery.get_slip39_identifier()
) )
result &= ( result &= (
storage_device.get_slip39_iteration_exponent() storage.device.get_slip39_iteration_exponent()
== storage_recovery.get_slip39_iteration_exponent() == storage.recovery.get_slip39_iteration_exponent()
) )
await layout.show_dry_run_result(ctx, result, is_slip39) await layout.show_dry_run_result(ctx, result, is_slip39)
storage_recovery.end_progress() storage.recovery.end_progress()
if result: if result:
return Success("The seed is valid and matches the one in the device") return Success("The seed is valid and matches the one in the device")
@ -136,21 +135,21 @@ async def _finish_recovery(
if backup_type is None: if backup_type is None:
raise RuntimeError raise RuntimeError
storage_device.store_mnemonic_secret( storage.device.store_mnemonic_secret(
secret, backup_type, needs_backup=False, no_backup=False secret, backup_type, needs_backup=False, no_backup=False
) )
if backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced): if backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced):
identifier = storage_recovery.get_slip39_identifier() identifier = storage.recovery.get_slip39_identifier()
exponent = storage_recovery.get_slip39_iteration_exponent() exponent = storage.recovery.get_slip39_iteration_exponent()
if identifier is None or exponent is None: if identifier is None or exponent is None:
# Identifier and exponent need to be stored in storage at this point # Identifier and exponent need to be stored in storage at this point
raise RuntimeError raise RuntimeError
storage_device.set_slip39_identifier(identifier) storage.device.set_slip39_identifier(identifier)
storage_device.set_slip39_iteration_exponent(exponent) storage.device.set_slip39_iteration_exponent(exponent)
await show_success(ctx, ("You have successfully", "recovered your wallet.")) await show_success(ctx, ("You have successfully", "recovered your wallet."))
storage_recovery.end_progress() storage.recovery.end_progress()
return Success(message="Device recovered") return Success(message="Device recovered")
@ -188,7 +187,7 @@ async def _request_share_first_screen(
ctx: wire.GenericContext, word_count: int ctx: wire.GenericContext, word_count: int
) -> None: ) -> None:
if backup_types.is_slip39_word_count(word_count): if backup_types.is_slip39_word_count(word_count):
remaining = storage_recovery.fetch_slip39_remaining_shares() remaining = storage.recovery.fetch_slip39_remaining_shares()
if remaining: if remaining:
await _request_share_next_screen(ctx) await _request_share_next_screen(ctx)
else: else:
@ -204,8 +203,8 @@ async def _request_share_first_screen(
async def _request_share_next_screen(ctx: wire.GenericContext) -> None: async def _request_share_next_screen(ctx: wire.GenericContext) -> None:
remaining = storage_recovery.fetch_slip39_remaining_shares() remaining = storage.recovery.fetch_slip39_remaining_shares()
group_count = storage_recovery.get_slip39_group_count() group_count = storage.recovery.get_slip39_group_count()
if not remaining: if not remaining:
# 'remaining' should be stored at this point # 'remaining' should be stored at this point
raise RuntimeError raise RuntimeError
@ -228,7 +227,7 @@ async def _show_remaining_groups_and_shares(ctx: wire.GenericContext) -> None:
""" """
Show info dialog for Slip39 Advanced - what shares are to be entered. Show info dialog for Slip39 Advanced - what shares are to be entered.
""" """
shares_remaining = storage_recovery.fetch_slip39_remaining_shares() shares_remaining = storage.recovery.fetch_slip39_remaining_shares()
# should be stored at this point # should be stored at this point
assert shares_remaining assert shares_remaining
@ -241,13 +240,13 @@ async def _show_remaining_groups_and_shares(ctx: wire.GenericContext) -> None:
share = None share = None
for index, remaining in enumerate(shares_remaining): for index, remaining in enumerate(shares_remaining):
if 0 <= remaining < slip39.MAX_SHARE_COUNT: if 0 <= remaining < slip39.MAX_SHARE_COUNT:
m = storage_recovery_shares.fetch_group(index)[0] m = storage.recovery_shares.fetch_group(index)[0]
if not share: if not share:
share = slip39.decode_mnemonic(m) share = slip39.decode_mnemonic(m)
identifier = m.split(" ")[0:3] identifier = m.split(" ")[0:3]
groups.add((remaining, tuple(identifier))) groups.add((remaining, tuple(identifier)))
elif remaining == slip39.MAX_SHARE_COUNT: # no shares yet elif remaining == slip39.MAX_SHARE_COUNT: # no shares yet
identifier = storage_recovery_shares.fetch_group(first_entered_index)[ identifier = storage.recovery_shares.fetch_group(first_entered_index)[
0 0
].split(" ")[0:2] ].split(" ")[0:2]
groups.add((remaining, tuple(identifier))) groups.add((remaining, tuple(identifier)))

@ -1,3 +1,4 @@
import storage.recovery
from trezor import ui, wire from trezor import ui, wire
from trezor.crypto.slip39 import MAX_SHARE_COUNT from trezor.crypto.slip39 import MAX_SHARE_COUNT
from trezor.messages import BackupType, ButtonRequestType from trezor.messages import BackupType, ButtonRequestType
@ -13,7 +14,6 @@ from .recover import RecoveryAborted
from apps.common.confirm import confirm, info_confirm, require_confirm from apps.common.confirm import confirm, info_confirm, require_confirm
from apps.common.layout import show_success, show_warning from apps.common.layout import show_success, show_warning
from apps.common.storage import recovery as storage_recovery
from apps.management import backup_types from apps.management import backup_types
from apps.management.recovery_device import recover from apps.management.recovery_device import recover
@ -127,7 +127,7 @@ async def check_word_validity(
if len(group) > 0: if len(group) > 0:
if current_word == group[0].split(" ")[current_index]: if current_word == group[0].split(" ")[current_index]:
remaining_shares = ( remaining_shares = (
storage_recovery.fetch_slip39_remaining_shares() storage.recovery.fetch_slip39_remaining_shares()
) )
# if backup_type is not None, some share was already entered -> remaining needs to be set # if backup_type is not None, some share was already entered -> remaining needs to be set
assert remaining_shares is not None assert remaining_shares is not None
@ -280,7 +280,7 @@ class RecoveryHomescreen(ui.Component):
def __init__(self, text: str, subtext: str = None): def __init__(self, text: str, subtext: str = None):
self.text = text self.text = text
self.subtext = subtext self.subtext = subtext
self.dry_run = storage_recovery.is_dry_run() self.dry_run = storage.recovery.is_dry_run()
self.repaint = True self.repaint = True
def on_render(self) -> None: def on_render(self) -> None:
@ -345,6 +345,6 @@ async def homescreen_dialog(
# go forward in the recovery process # go forward in the recovery process
break break
# user has chosen to abort, confirm the choice # user has chosen to abort, confirm the choice
dry_run = storage_recovery.is_dry_run() dry_run = storage.recovery.is_dry_run()
if await confirm_abort(ctx, dry_run): if await confirm_abort(ctx, dry_run):
raise RecoveryAborted raise RecoveryAborted

@ -1,10 +1,8 @@
import storage.recovery
import storage.recovery_shares
from trezor.crypto import bip39, slip39 from trezor.crypto import bip39, slip39
from trezor.errors import MnemonicError from trezor.errors import MnemonicError
from apps.common.storage import (
recovery as storage_recovery,
recovery_shares as storage_recovery_shares,
)
from apps.management import backup_types from apps.management import backup_types
if False: if False:
@ -33,17 +31,17 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]:
""" """
share = slip39.decode_mnemonic(words) share = slip39.decode_mnemonic(words)
remaining = storage_recovery.fetch_slip39_remaining_shares() remaining = storage.recovery.fetch_slip39_remaining_shares()
# if this is the first share, parse and store metadata # if this is the first share, parse and store metadata
if not remaining: if not remaining:
storage_recovery.set_slip39_group_count(share.group_count) storage.recovery.set_slip39_group_count(share.group_count)
storage_recovery.set_slip39_iteration_exponent(share.iteration_exponent) storage.recovery.set_slip39_iteration_exponent(share.iteration_exponent)
storage_recovery.set_slip39_identifier(share.identifier) storage.recovery.set_slip39_identifier(share.identifier)
storage_recovery.set_slip39_remaining_shares( storage.recovery.set_slip39_remaining_shares(
share.threshold - 1, share.group_index share.threshold - 1, share.group_index
) )
storage_recovery_shares.set(share.index, share.group_index, words) storage.recovery_shares.set(share.index, share.group_index, words)
# if share threshold and group threshold are 1 # if share threshold and group threshold are 1
# we can calculate the secret right away # we can calculate the secret right away
@ -57,24 +55,24 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]:
return None, share return None, share
# These should be checked by UI before so it's a Runtime exception otherwise # These should be checked by UI before so it's a Runtime exception otherwise
if share.identifier != storage_recovery.get_slip39_identifier(): if share.identifier != storage.recovery.get_slip39_identifier():
raise RuntimeError("Slip39: Share identifiers do not match") raise RuntimeError("Slip39: Share identifiers do not match")
if share.iteration_exponent != storage_recovery.get_slip39_iteration_exponent(): if share.iteration_exponent != storage.recovery.get_slip39_iteration_exponent():
raise RuntimeError("Slip39: Share exponents do not match") raise RuntimeError("Slip39: Share exponents do not match")
if storage_recovery_shares.get(share.index, share.group_index): if storage.recovery_shares.get(share.index, share.group_index):
raise RuntimeError("Slip39: This mnemonic was already entered") raise RuntimeError("Slip39: This mnemonic was already entered")
if share.group_count != storage_recovery.get_slip39_group_count(): if share.group_count != storage.recovery.get_slip39_group_count():
raise RuntimeError("Slip39: Group count does not match") raise RuntimeError("Slip39: Group count does not match")
remaining_for_share = ( remaining_for_share = (
storage_recovery.get_slip39_remaining_shares(share.group_index) storage.recovery.get_slip39_remaining_shares(share.group_index)
or share.threshold or share.threshold
) )
storage_recovery.set_slip39_remaining_shares( storage.recovery.set_slip39_remaining_shares(
remaining_for_share - 1, share.group_index remaining_for_share - 1, share.group_index
) )
remaining[share.group_index] = remaining_for_share - 1 remaining[share.group_index] = remaining_for_share - 1
storage_recovery_shares.set(share.index, share.group_index, words) storage.recovery_shares.set(share.index, share.group_index, words)
if remaining.count(0) < share.group_threshold: if remaining.count(0) < share.group_threshold:
# we need more shares # we need more shares
@ -85,11 +83,11 @@ def process_slip39(words: str) -> Tuple[Optional[bytes], slip39.Share]:
for i, r in enumerate(remaining): for i, r in enumerate(remaining):
# if we have multiple groups pass only the ones with threshold reached # if we have multiple groups pass only the ones with threshold reached
if r == 0: if r == 0:
group = storage_recovery_shares.fetch_group(i) group = storage.recovery_shares.fetch_group(i)
mnemonics.extend(group) mnemonics.extend(group)
else: else:
# in case of slip39 basic we only need the first and only group # in case of slip39 basic we only need the first and only group
mnemonics = storage_recovery_shares.fetch_group(0) mnemonics = storage.recovery_shares.fetch_group(0)
identifier, iteration_exponent, secret, _ = slip39.combine_mnemonics(mnemonics) identifier, iteration_exponent, secret, _ = slip39.combine_mnemonics(mnemonics)
return secret, share return secret, share
@ -112,10 +110,10 @@ def load_slip39_state() -> Slip39State:
def fetch_previous_mnemonics() -> Optional[List[List[str]]]: def fetch_previous_mnemonics() -> Optional[List[List[str]]]:
mnemonics = [] mnemonics = []
if not storage_recovery.get_slip39_group_count(): if not storage.recovery.get_slip39_group_count():
return None return None
for i in range(storage_recovery.get_slip39_group_count()): for i in range(storage.recovery.get_slip39_group_count()):
mnemonics.append(storage_recovery_shares.fetch_group(i)) mnemonics.append(storage.recovery_shares.fetch_group(i))
if not any(p for p in mnemonics): if not any(p for p in mnemonics):
return None return None
return mnemonics return mnemonics

@ -1,3 +1,5 @@
import storage
import storage.device
from trezor import config, wire from trezor import config, wire
from trezor.crypto import bip39, hashlib, random, slip39 from trezor.crypto import bip39, hashlib, random, slip39
from trezor.messages import BackupType from trezor.messages import BackupType
@ -6,8 +8,6 @@ from trezor.messages.EntropyRequest import EntropyRequest
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.pin import pin_to_int from trezor.pin import pin_to_int
from apps.common import storage
from apps.common.storage import device as storage_device
from apps.management import backup_types from apps.management import backup_types
from apps.management.change_pin import request_pin_confirm from apps.management.change_pin import request_pin_confirm
from apps.management.reset_device import layout from apps.management.reset_device import layout
@ -53,8 +53,8 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
secret = bip39.from_data(secret).encode() secret = bip39.from_data(secret).encode()
elif msg.backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced): elif msg.backup_type in (BackupType.Slip39_Basic, BackupType.Slip39_Advanced):
# generate and set SLIP39 parameters # generate and set SLIP39 parameters
storage_device.set_slip39_identifier(slip39.generate_random_identifier()) storage.device.set_slip39_identifier(slip39.generate_random_identifier())
storage_device.set_slip39_iteration_exponent(slip39.DEFAULT_ITERATION_EXPONENT) storage.device.set_slip39_iteration_exponent(slip39.DEFAULT_ITERATION_EXPONENT)
else: else:
# Unknown backup type. # Unknown backup type.
raise RuntimeError raise RuntimeError
@ -72,10 +72,10 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
await backup_seed(ctx, msg.backup_type, secret) await backup_seed(ctx, msg.backup_type, secret)
# write settings and master secret into storage # write settings and master secret into storage
storage_device.load_settings( storage.device.load_settings(
label=msg.label, use_passphrase=msg.passphrase_protection label=msg.label, use_passphrase=msg.passphrase_protection
) )
storage_device.store_mnemonic_secret( storage.device.store_mnemonic_secret(
secret, # for SLIP-39, this is the EMS secret, # for SLIP-39, this is the EMS
msg.backup_type, msg.backup_type,
needs_backup=not perform_backup, needs_backup=not perform_backup,
@ -103,10 +103,10 @@ async def backup_slip39_basic(
# generate the mnemonics # generate the mnemonics
mnemonics = slip39.generate_mnemonics_from_data( mnemonics = slip39.generate_mnemonics_from_data(
encrypted_master_secret, encrypted_master_secret,
storage_device.get_slip39_identifier(), storage.device.get_slip39_identifier(),
1, # Single Group threshold 1, # Single Group threshold
[(threshold, shares_count)], # Single Group threshold/count [(threshold, shares_count)], # Single Group threshold/count
storage_device.get_slip39_iteration_exponent(), storage.device.get_slip39_iteration_exponent(),
)[0] )[0]
# show and confirm individual shares # show and confirm individual shares
@ -138,10 +138,10 @@ async def backup_slip39_advanced(
# generate the mnemonics # generate the mnemonics
mnemonics = slip39.generate_mnemonics_from_data( mnemonics = slip39.generate_mnemonics_from_data(
encrypted_master_secret=encrypted_master_secret, encrypted_master_secret=encrypted_master_secret,
identifier=storage_device.get_slip39_identifier(), identifier=storage.device.get_slip39_identifier(),
group_threshold=group_threshold, group_threshold=group_threshold,
groups=groups, groups=groups,
iteration_exponent=storage_device.get_slip39_iteration_exponent(), iteration_exponent=storage.device.get_slip39_iteration_exponent(),
) )
# show and confirm individual shares # show and confirm individual shares

@ -1,6 +1,7 @@
import storage.device
import storage.sd_salt
from trezor import config, ui, wire from trezor import config, ui, wire
from trezor.crypto import hmac, random from trezor.crypto import random
from trezor.crypto.hashlib import sha256
from trezor.messages import SdProtectOperationType from trezor.messages import SdProtectOperationType
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.pin import pin_to_int from trezor.pin import pin_to_int
@ -13,23 +14,33 @@ from apps.common.request_pin import (
request_pin_and_sd_salt, request_pin_and_sd_salt,
show_pin_invalid, show_pin_invalid,
) )
from apps.common.sd_salt import ( from apps.common.sd_salt import ensure_sd_card, sd_write_failed_dialog
SD_SALT_AUTH_KEY_LEN_BYTES,
SD_SALT_AUTH_TAG_LEN_BYTES,
SD_SALT_LEN_BYTES,
commit_sd_salt,
remove_sd_salt,
set_sd_salt,
stage_sd_salt,
)
from apps.common.storage import device, is_initialized
if False: if False:
from typing import Awaitable, Tuple
from trezor.messages.SdProtect import SdProtect from trezor.messages.SdProtect import SdProtect
def _make_salt() -> Tuple[bytes, bytes, bytes]:
salt = random.bytes(storage.sd_salt.SD_SALT_LEN_BYTES)
auth_key = random.bytes(storage.device.SD_SALT_AUTH_KEY_LEN_BYTES)
tag = storage.sd_salt.compute_auth_tag(salt, auth_key)
return salt, auth_key, tag
async def _set_salt(
ctx: wire.Context, salt: bytes, salt_tag: bytes, stage: bool = False
) -> None:
while True:
try:
return storage.sd_salt.set_sd_salt(salt, salt_tag, stage)
except OSError:
if not await sd_write_failed_dialog(ctx):
raise
async def sd_protect(ctx: wire.Context, msg: SdProtect) -> Success: async def sd_protect(ctx: wire.Context, msg: SdProtect) -> Success:
if not is_initialized(): if not storage.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
if msg.operation == SdProtectOperationType.ENABLE: if msg.operation == SdProtectOperationType.ENABLE:
@ -43,13 +54,15 @@ async def sd_protect(ctx: wire.Context, msg: SdProtect) -> Success:
async def sd_protect_enable(ctx: wire.Context, msg: SdProtect) -> Success: async def sd_protect_enable(ctx: wire.Context, msg: SdProtect) -> Success:
salt_auth_key = device.get_sd_salt_auth_key() if storage.sd_salt.is_enabled():
if salt_auth_key is not None:
raise wire.ProcessError("SD card protection already enabled") raise wire.ProcessError("SD card protection already enabled")
# Confirm that user wants to proceed with the operation. # Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg) await require_confirm_sd_protect(ctx, msg)
# Make sure SD card is available.
await ensure_sd_card(ctx)
# Get the current PIN. # Get the current PIN.
if config.has_pin(): if config.has_pin():
pin = pin_to_int(await request_pin_ack(ctx, "Enter PIN", config.get_pin_rem())) pin = pin_to_int(await request_pin_ack(ctx, "Enter PIN", config.get_pin_rem()))
@ -57,17 +70,13 @@ async def sd_protect_enable(ctx: wire.Context, msg: SdProtect) -> Success:
pin = pin_to_int("") pin = pin_to_int("")
# Check PIN and prepare salt file. # Check PIN and prepare salt file.
salt = random.bytes(SD_SALT_LEN_BYTES) salt, salt_auth_key, salt_tag = _make_salt()
salt_auth_key = random.bytes(SD_SALT_AUTH_KEY_LEN_BYTES) await _set_salt(ctx, salt, salt_tag)
salt_tag = hmac.new(salt_auth_key, salt, sha256).digest()[
:SD_SALT_AUTH_TAG_LEN_BYTES
]
await set_sd_salt(ctx, salt, salt_tag)
if not config.change_pin(pin, pin, None, salt): if not config.change_pin(pin, pin, None, salt):
# Wrong PIN. Clean up the prepared salt file. # Wrong PIN. Clean up the prepared salt file.
try: try:
await remove_sd_salt(ctx) storage.sd_salt.remove_sd_salt()
except Exception: except Exception:
# The cleanup is not necessary for the correct functioning of # The cleanup is not necessary for the correct functioning of
# SD-protection. If it fails for any reason, we suppress the # SD-protection. If it fails for any reason, we suppress the
@ -76,16 +85,19 @@ async def sd_protect_enable(ctx: wire.Context, msg: SdProtect) -> Success:
await show_pin_invalid(ctx) await show_pin_invalid(ctx)
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
device.set_sd_salt_auth_key(salt_auth_key) storage.device.set_sd_salt_auth_key(salt_auth_key)
await show_success(ctx, ("You have successfully", "enabled SD protection.")) await show_success(ctx, ("You have successfully", "enabled SD protection."))
return Success(message="SD card protection enabled") return Success(message="SD card protection enabled")
async def sd_protect_disable(ctx: wire.Context, msg: SdProtect) -> Success: async def sd_protect_disable(ctx: wire.Context, msg: SdProtect) -> Success:
if device.get_sd_salt_auth_key() is None: if not storage.sd_salt.is_enabled():
raise wire.ProcessError("SD card protection not enabled") raise wire.ProcessError("SD card protection not enabled")
# Note that the SD card doesn't need to be accessible in order to disable SD
# protection. The cleanup will not happen in such case, but that does not matter.
# Confirm that user wants to proceed with the operation. # Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg) await require_confirm_sd_protect(ctx, msg)
@ -97,11 +109,11 @@ async def sd_protect_disable(ctx: wire.Context, msg: SdProtect) -> Success:
await show_pin_invalid(ctx) await show_pin_invalid(ctx)
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
device.set_sd_salt_auth_key(None) storage.device.set_sd_salt_auth_key(None)
try: try:
# Clean up. # Clean up.
await remove_sd_salt(ctx) storage.sd_salt.remove_sd_salt()
except Exception: except Exception:
# The cleanup is not necessary for the correct functioning of # The cleanup is not necessary for the correct functioning of
# SD-protection. If it fails for any reason, we suppress the exception, # SD-protection. If it fails for any reason, we suppress the exception,
@ -113,32 +125,31 @@ async def sd_protect_disable(ctx: wire.Context, msg: SdProtect) -> Success:
async def sd_protect_refresh(ctx: wire.Context, msg: SdProtect) -> Success: async def sd_protect_refresh(ctx: wire.Context, msg: SdProtect) -> Success:
if device.get_sd_salt_auth_key() is None: if not storage.sd_salt.is_enabled():
raise wire.ProcessError("SD card protection not enabled") raise wire.ProcessError("SD card protection not enabled")
# Confirm that user wants to proceed with the operation. # Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg) await require_confirm_sd_protect(ctx, msg)
# Make sure SD card is available.
await ensure_sd_card(ctx)
# Get the current PIN and salt from the SD card. # Get the current PIN and salt from the SD card.
pin, old_salt = await request_pin_and_sd_salt(ctx, "Enter PIN") pin, old_salt = await request_pin_and_sd_salt(ctx, "Enter PIN")
# Check PIN and change salt. # Check PIN and change salt.
new_salt = random.bytes(SD_SALT_LEN_BYTES) new_salt, new_auth_key, new_salt_tag = _make_salt()
new_salt_auth_key = random.bytes(SD_SALT_AUTH_KEY_LEN_BYTES) await _set_salt(ctx, new_salt, new_salt_tag, stage=True)
new_salt_tag = hmac.new(new_salt_auth_key, new_salt, sha256).digest()[
:SD_SALT_AUTH_TAG_LEN_BYTES
]
await stage_sd_salt(ctx, new_salt, new_salt_tag)
if not config.change_pin(pin_to_int(pin), pin_to_int(pin), old_salt, new_salt): if not config.change_pin(pin_to_int(pin), pin_to_int(pin), old_salt, new_salt):
await show_pin_invalid(ctx) await show_pin_invalid(ctx)
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
device.set_sd_salt_auth_key(new_salt_auth_key) storage.device.set_sd_salt_auth_key(new_auth_key)
try: try:
# Clean up. # Clean up.
await commit_sd_salt(ctx) storage.sd_salt.commit_sd_salt()
except Exception: except Exception:
# If the cleanup fails, then request_sd_salt() will bring the SD card # If the cleanup fails, then request_sd_salt() will bring the SD card
# into a consistent state. We suppress the exception, because overall # into a consistent state. We suppress the exception, because overall
@ -149,7 +160,7 @@ async def sd_protect_refresh(ctx: wire.Context, msg: SdProtect) -> Success:
return Success(message="SD card protection refreshed") return Success(message="SD card protection refreshed")
def require_confirm_sd_protect(ctx: wire.Context, msg: SdProtect) -> None: def require_confirm_sd_protect(ctx: wire.Context, msg: SdProtect) -> Awaitable[None]:
if msg.operation == SdProtectOperationType.ENABLE: if msg.operation == SdProtectOperationType.ENABLE:
text = Text("SD card protection", ui.ICON_CONFIG) text = Text("SD card protection", ui.ICON_CONFIG)
text.normal( text.normal(

@ -1,10 +1,10 @@
import storage.device
from trezor import ui, wire from trezor import ui, wire
from trezor.messages import ButtonRequestType from trezor.messages import ButtonRequestType
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.storage import device as storage_device
async def set_u2f_counter(ctx, msg): async def set_u2f_counter(ctx, msg):
@ -16,6 +16,6 @@ async def set_u2f_counter(ctx, msg):
text.bold("to %d?" % msg.u2f_counter) text.bold("to %d?" % msg.u2f_counter)
await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall) await require_confirm(ctx, text, code=ButtonRequestType.ProtectCall)
storage_device.set_u2f_counter(msg.u2f_counter) storage.device.set_u2f_counter(msg.u2f_counter)
return Success(message="U2F counter set") return Success(message="U2F counter set")

@ -1,3 +1,4 @@
import storage
from trezor import ui from trezor import ui
from trezor.messages import ButtonRequestType from trezor.messages import ButtonRequestType
from trezor.messages.Success import Success from trezor.messages.Success import Success
@ -5,7 +6,6 @@ from trezor.ui.button import ButtonCancel
from trezor.ui.loader import LoaderDanger from trezor.ui.loader import LoaderDanger
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common import storage
from apps.common.confirm import require_hold_to_confirm from apps.common.confirm import require_hold_to_confirm

@ -1,5 +1,6 @@
import gc import gc
from storage.cache import get_passphrase_fprint
from trezor import log from trezor import log
from trezor.messages import MessageType from trezor.messages import MessageType
from trezor.messages.MoneroLiveRefreshFinalAck import MoneroLiveRefreshFinalAck from trezor.messages.MoneroLiveRefreshFinalAck import MoneroLiveRefreshFinalAck
@ -9,7 +10,6 @@ from trezor.messages.MoneroLiveRefreshStepAck import MoneroLiveRefreshStepAck
from trezor.messages.MoneroLiveRefreshStepRequest import MoneroLiveRefreshStepRequest from trezor.messages.MoneroLiveRefreshStepRequest import MoneroLiveRefreshStepRequest
from apps.common import paths from apps.common import paths
from apps.common.cache import get_passphrase_fprint
from apps.monero import CURVE, live_refresh_token, misc from apps.monero import CURVE, live_refresh_token, misc
from apps.monero.layout import confirms from apps.monero.layout import confirms
from apps.monero.xmr import crypto, key_image, monero from apps.monero.xmr import crypto, key_image, monero

@ -4,9 +4,9 @@ from trezor.messages.WebAuthnAddResidentCredential import WebAuthnAddResidentCre
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.storage.webauthn import store_resident_credential
from apps.webauthn.confirm import ConfirmContent, ConfirmInfo from apps.webauthn.confirm import ConfirmContent, ConfirmInfo
from apps.webauthn.credential import Fido2Credential from apps.webauthn.credential import Fido2Credential
from apps.webauthn.resident_credentials import store_resident_credential
if False: if False:
from typing import Optional from typing import Optional
@ -33,8 +33,10 @@ async def add_resident_credential(
if not msg.credential_id: if not msg.credential_id:
raise wire.ProcessError("Missing credential ID parameter.") raise wire.ProcessError("Missing credential ID parameter.")
cred = Fido2Credential.from_cred_id(msg.credential_id, None) try:
if cred is None: cred = Fido2Credential.from_cred_id(msg.credential_id, None)
except Exception:
text = Text("Import credential", ui.ICON_WRONG, ui.RED) text = Text("Import credential", ui.ICON_WRONG, ui.RED)
text.normal( text.normal(
"The credential you are", "The credential you are",
@ -43,7 +45,7 @@ async def add_resident_credential(
"authenticator.", "authenticator.",
) )
await require_confirm(ctx, text, confirm=None, cancel="Close") await require_confirm(ctx, text, confirm=None, cancel="Close")
raise wire.ActionCancelled("Cancelled") raise wire.ActionCancelled("Cancelled") from None
content = ConfirmContent(ConfirmAddCredential(cred)) content = ConfirmContent(ConfirmAddCredential(cred))
await require_confirm(ctx, content) await require_confirm(ctx, content)

@ -2,11 +2,11 @@ import ustruct
from micropython import const from micropython import const
from ubinascii import hexlify from ubinascii import hexlify
import storage.device
from trezor import log, utils from trezor import log, utils
from trezor.crypto import bip32, chacha20poly1305, hashlib, hmac, random from trezor.crypto import bip32, chacha20poly1305, hashlib, hmac, random
from apps.common import HARDENED, cbor, seed from apps.common import HARDENED, cbor, seed
from apps.common.storage import device as storage_device
if False: if False:
from typing import Optional from typing import Optional
@ -51,16 +51,14 @@ class Credential:
return None return None
def next_signature_counter(self) -> int: def next_signature_counter(self) -> int:
return storage_device.next_u2f_counter() or 0 return storage.device.next_u2f_counter() or 0
@staticmethod @staticmethod
def from_bytes(data: bytes, rp_id_hash: bytes) -> Optional["Credential"]: def from_bytes(data: bytes, rp_id_hash: bytes) -> Optional["Credential"]:
cred = Fido2Credential.from_cred_id( try:
data, rp_id_hash return Fido2Credential.from_cred_id(data, rp_id_hash)
) # type: Optional[Credential] except Exception:
if cred is None: return U2fCredential.from_key_handle(data, rp_id_hash)
cred = U2fCredential.from_key_handle(data, rp_id_hash)
return cred
# SLIP-0022: FIDO2 credential ID format for HD wallets # SLIP-0022: FIDO2 credential ID format for HD wallets
@ -83,7 +81,7 @@ class Fido2Credential(Credential):
return True return True
def generate_id(self) -> None: def generate_id(self) -> None:
self.creation_time = storage_device.next_u2f_counter() or 0 self.creation_time = storage.device.next_u2f_counter() or 0
data = cbor.encode( data = cbor.encode(
{ {
@ -111,12 +109,12 @@ class Fido2Credential(Credential):
tag = ctx.finish() tag = ctx.finish()
self.id = _CRED_ID_VERSION + iv + ciphertext + tag self.id = _CRED_ID_VERSION + iv + ciphertext + tag
@staticmethod @classmethod
def from_cred_id( def from_cred_id(
cred_id: bytes, rp_id_hash: Optional[bytes] cls, cred_id: bytes, rp_id_hash: Optional[bytes]
) -> Optional["Fido2Credential"]: ) -> "Fido2Credential":
if len(cred_id) < _CRED_ID_MIN_LENGTH or cred_id[0:4] != _CRED_ID_VERSION: if len(cred_id) < _CRED_ID_MIN_LENGTH or cred_id[0:4] != _CRED_ID_VERSION:
return None raise ValueError # invalid length or version
key = seed.derive_slip21_node_without_passphrase( key = seed.derive_slip21_node_without_passphrase(
[b"SLIP-0022", cred_id[0:4], b"Encryption key"] [b"SLIP-0022", cred_id[0:4], b"Encryption key"]
@ -130,25 +128,25 @@ class Fido2Credential(Credential):
data = ctx.decrypt(ciphertext) data = ctx.decrypt(ciphertext)
try: try:
rp_id = cbor.decode(data)[_CRED_ID_RP_ID] rp_id = cbor.decode(data)[_CRED_ID_RP_ID]
except Exception: except Exception as e:
return None raise ValueError from e # CBOR decoding failed
rp_id_hash = hashlib.sha256(rp_id).digest() rp_id_hash = hashlib.sha256(rp_id).digest()
ctx = chacha20poly1305(key, iv) ctx = chacha20poly1305(key, iv)
ctx.auth(rp_id_hash) ctx.auth(rp_id_hash)
data = ctx.decrypt(ciphertext) data = ctx.decrypt(ciphertext)
if not utils.consteq(ctx.finish(), tag): if not utils.consteq(ctx.finish(), tag):
return None raise ValueError # inauthentic ciphertext
try: try:
data = cbor.decode(data) data = cbor.decode(data)
except Exception: except Exception as e:
return None raise ValueError from e # CBOR decoding failed
if not isinstance(data, dict): if not isinstance(data, dict):
return None raise ValueError # invalid CBOR data
cred = Fido2Credential() cred = cls()
cred.rp_id = data.get(_CRED_ID_RP_ID, None) cred.rp_id = data.get(_CRED_ID_RP_ID, None)
cred.rp_id_hash = rp_id_hash cred.rp_id_hash = rp_id_hash
cred.rp_name = data.get(_CRED_ID_RP_NAME, None) cred.rp_name = data.get(_CRED_ID_RP_NAME, None)
@ -165,7 +163,7 @@ class Fido2Credential(Credential):
or not cred.check_data_types() or not cred.check_data_types()
or hashlib.sha256(cred.rp_id).digest() != rp_id_hash or hashlib.sha256(cred.rp_id).digest() != rp_id_hash
): ):
return None raise ValueError # data consistency check failed
return cred return cred

@ -3,6 +3,8 @@ import ustruct
import utime import utime
from micropython import const from micropython import const
import storage
import storage.webauthn
from trezor import config, io, log, loop, ui, utils, workflow from trezor import config, io, log, loop, ui, utils, workflow
from trezor.crypto import aes, der, hashlib, hmac, random from trezor.crypto import aes, der, hashlib, hmac, random
from trezor.crypto.curve import nist256p1 from trezor.crypto.curve import nist256p1
@ -10,14 +12,13 @@ from trezor.ui.confirm import CONFIRMED, Confirm, ConfirmPageable, Pageable
from trezor.ui.popup import Popup from trezor.ui.popup import Popup
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common import cbor, storage from apps.common import cbor
from apps.common.storage.webauthn import (
erase_resident_credentials,
get_resident_credentials,
store_resident_credential,
)
from apps.webauthn.confirm import ConfirmContent, ConfirmInfo from apps.webauthn.confirm import ConfirmContent, ConfirmInfo
from apps.webauthn.credential import Credential, Fido2Credential, U2fCredential from apps.webauthn.credential import Credential, Fido2Credential, U2fCredential
from apps.webauthn.resident_credentials import (
find_by_rp_id_hash,
store_resident_credential,
)
if __debug__: if __debug__:
from apps.debug import confirm_signal from apps.debug import confirm_signal
@ -863,7 +864,7 @@ class Fido2ConfirmReset(Fido2State):
return await confirm(text) return await confirm(text)
async def on_confirm(self) -> None: async def on_confirm(self) -> None:
erase_resident_credentials() storage.webauthn.delete_all_resident_credentials()
cmd = Cmd(self.cid, _CMD_CBOR, bytes([_ERR_NONE])) cmd = Cmd(self.cid, _CMD_CBOR, bytes([_ERR_NONE]))
await send_cmd(cmd, self.iface) await send_cmd(cmd, self.iface)
@ -1481,7 +1482,7 @@ def cbor_get_assertion(req: Cmd, dialog_mgr: DialogManager) -> Optional[Cmd]:
else: else:
# Allow list is empty. Get resident credentials. # Allow list is empty. Get resident credentials.
if _ALLOW_RESIDENT_CREDENTIALS: if _ALLOW_RESIDENT_CREDENTIALS:
cred_list = get_resident_credentials(rp_id_hash) cred_list = list(find_by_rp_id_hash(rp_id_hash))
else: else:
cred_list = [] cred_list = []
resident = True resident = True

@ -7,7 +7,7 @@ from trezor.messages.WebAuthnListResidentCredentials import (
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.storage.webauthn import get_resident_credentials from apps.webauthn import resident_credentials
async def list_resident_credentials( async def list_resident_credentials(
@ -34,6 +34,6 @@ async def list_resident_credentials(
hmac_secret=cred.hmac_secret, hmac_secret=cred.hmac_secret,
use_sign_count=cred.use_sign_count, use_sign_count=cred.use_sign_count,
) )
for cred in get_resident_credentials() for cred in resident_credentials.find_all()
] ]
return WebAuthnCredentials(creds) return WebAuthnCredentials(creds)

@ -1,3 +1,4 @@
import storage.webauthn
from trezor import wire from trezor import wire
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.messages.WebAuthnRemoveResidentCredential import ( from trezor.messages.WebAuthnRemoveResidentCredential import (
@ -5,12 +6,9 @@ from trezor.messages.WebAuthnRemoveResidentCredential import (
) )
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.storage.webauthn import (
erase_resident_credential,
get_resident_credential,
)
from apps.webauthn.confirm import ConfirmContent, ConfirmInfo from apps.webauthn.confirm import ConfirmContent, ConfirmInfo
from apps.webauthn.credential import Fido2Credential from apps.webauthn.credential import Fido2Credential
from apps.webauthn.resident_credentials import get_resident_credential
if False: if False:
from typing import Optional from typing import Optional
@ -44,5 +42,6 @@ async def remove_resident_credential(
content = ConfirmContent(ConfirmRemoveCredential(cred)) content = ConfirmContent(ConfirmRemoveCredential(cred))
await require_confirm(ctx, content) await require_confirm(ctx, content)
erase_resident_credential(msg.index) assert cred.index is not None
storage.webauthn.delete_resident_credential(cred.index)
return Success(message="Credential removed") return Success(message="Credential removed")

@ -0,0 +1,81 @@
from micropython import const
import storage.webauthn
from storage.webauthn import MAX_RESIDENT_CREDENTIALS
from apps.webauthn.credential import Fido2Credential
if False:
from typing import Iterator, Optional
RP_ID_HASH_LENGTH = const(32)
def _credential_from_data(index: int, data: bytes) -> Fido2Credential:
rp_id_hash = data[:RP_ID_HASH_LENGTH]
cred_id = data[RP_ID_HASH_LENGTH:]
cred = Fido2Credential.from_cred_id(cred_id, rp_id_hash)
cred.index = index
return cred
def find_all() -> Iterator[Fido2Credential]:
for index in range(MAX_RESIDENT_CREDENTIALS):
data = storage.webauthn.get_resident_credential(index)
if data is not None:
yield _credential_from_data(index, data)
def find_by_rp_id_hash(rp_id_hash: bytes) -> Iterator[Fido2Credential]:
for index in range(MAX_RESIDENT_CREDENTIALS):
data = storage.webauthn.get_resident_credential(index)
if data is None:
# empty slot
continue
if data[:RP_ID_HASH_LENGTH] != rp_id_hash:
# rp_id_hash mismatch
continue
yield _credential_from_data(index, data)
def get_resident_credential(index: int) -> Optional[Fido2Credential]:
if not (0 <= index < MAX_RESIDENT_CREDENTIALS):
return None
data = storage.webauthn.get_resident_credential(index)
if data is None:
return None
return _credential_from_data(index, data)
def store_resident_credential(cred: Fido2Credential) -> bool:
slot = None
for index in range(MAX_RESIDENT_CREDENTIALS):
data = storage.webauthn.get_resident_credential(index)
if data is None:
# found candidate empty slot
if slot is None:
slot = index
continue
if cred.rp_id_hash != data[:RP_ID_HASH_LENGTH]:
# slot is occupied by a different rp_id_hash
continue
stored_cred = _credential_from_data(index, data)
# If a credential for the same RP ID and user ID already exists, then overwrite it.
if stored_cred.user_id == cred.user_id:
slot = index
break
if slot is None:
return False
cred_data = cred.rp_id_hash + cred.id
storage.webauthn.set_resident_credential(slot, cred_data)
return True

@ -1,30 +1,22 @@
from trezor import config, io, log, loop, res, ui, utils import storage
import storage.device
import storage.sd_salt
from trezor import config, io, log, loop, res, ui, utils, wire
from trezor.pin import pin_to_int, show_pin_timeout from trezor.pin import pin_to_int, show_pin_timeout
from apps.common import storage
from apps.common.request_pin import PinCancelled, request_pin from apps.common.request_pin import PinCancelled, request_pin
from apps.common.sd_salt import SdProtectCancelled, request_sd_salt from apps.common.sd_salt import SdProtectCancelled, request_sd_salt
from apps.common.storage import device as storage_device
if False:
from typing import Optional
async def bootscreen() -> None: async def bootscreen() -> None:
ui.display.orientation(storage_device.get_rotation()) ui.display.orientation(storage.device.get_rotation())
salt_auth_key = storage_device.get_sd_salt_auth_key()
while True: while True:
try: try:
if salt_auth_key is not None or config.has_pin(): if storage.sd_salt.is_enabled() or config.has_pin():
await lockscreen() await lockscreen()
if salt_auth_key is not None: salt = await request_sd_salt(wire.DummyContext())
salt = await request_sd_salt(
None, salt_auth_key
) # type: Optional[bytearray]
else:
salt = None
if not config.has_pin(): if not config.has_pin():
config.unlock(pin_to_int(""), salt) config.unlock(pin_to_int(""), salt)
@ -43,12 +35,13 @@ async def bootscreen() -> None:
if __debug__: if __debug__:
log.exception(__name__, e) log.exception(__name__, e)
except Exception as e: except Exception as e:
print(e)
utils.halt(e.__class__.__name__) utils.halt(e.__class__.__name__)
async def lockscreen() -> None: async def lockscreen() -> None:
label = storage_device.get_label() label = storage.device.get_label()
image = storage_device.get_homescreen() image = storage.device.get_homescreen()
if not label: if not label:
label = "My Trezor" label = "My Trezor"
if not image: if not image:

@ -76,8 +76,8 @@ def _boot_default() -> None:
workflow.start_default(homescreen) workflow.start_default(homescreen)
import storage.recovery
from trezor import loop, wire, workflow from trezor import loop, wire, workflow
from apps.common.storage import recovery
while True: while True:
# initialize the wire codec # initialize the wire codec
@ -86,7 +86,7 @@ while True:
wire.setup(usb.iface_debug) wire.setup(usb.iface_debug)
# boot either in recovery or default mode # boot either in recovery or default mode
if recovery.is_in_progress(): if storage.recovery.is_in_progress():
_boot_recovery() _boot_recovery()
else: else:
_boot_default() _boot_default()

@ -1,8 +1,6 @@
from storage import cache, common, device
from trezor import config from trezor import config
from apps.common import cache
from apps.common.storage import common, device
def set_current_version() -> None: def set_current_version() -> None:
device.set_version(common.STORAGE_VERSION_CURRENT) device.set_version(common.STORAGE_VERSION_CURRENT)

@ -1,7 +1,6 @@
from storage.device import get_device_id
from trezor.crypto import hashlib, hmac, random from trezor.crypto import hashlib, hmac, random
from apps.common.storage.device import get_device_id
if False: if False:
from typing import Optional from typing import Optional

@ -1,12 +1,10 @@
from micropython import const from micropython import const
from ubinascii import hexlify from ubinascii import hexlify
from storage import common
from trezor.crypto import random from trezor.crypto import random
from trezor.messages import BackupType from trezor.messages import BackupType
from apps.common.sd_salt import SD_SALT_AUTH_KEY_LEN_BYTES
from apps.common.storage import common
if False: if False:
from trezor.messages.ResetDevice import EnumTypeBackupType from trezor.messages.ResetDevice import EnumTypeBackupType
from typing import Optional from typing import Optional
@ -41,6 +39,10 @@ _DEFAULT_BACKUP_TYPE = BackupType.Bip39
HOMESCREEN_MAXSIZE = 16384 HOMESCREEN_MAXSIZE = 16384
# Length of SD salt auth tag.
# Other SD-salt-related constants are in sd_salt.py
SD_SALT_AUTH_KEY_LEN_BYTES = const(16)
def is_version_stored() -> bool: def is_version_stored() -> bool:
return bool(common.get(_NAMESPACE, _VERSION)) return bool(common.get(_NAMESPACE, _VERSION))

@ -1,9 +1,8 @@
from micropython import const from micropython import const
from storage import common, recovery_shares
from trezor.crypto import slip39 from trezor.crypto import slip39
from apps.common.storage import common, recovery_shares
# Namespace: # Namespace:
_NAMESPACE = common.APP_RECOVERY _NAMESPACE = common.APP_RECOVERY

@ -1,7 +1,6 @@
from storage import common
from trezor.crypto import slip39 from trezor.crypto import slip39
from apps.common.storage import common
if False: if False:
from typing import List, Optional from typing import List, Optional

@ -0,0 +1,160 @@
from micropython import const
import storage.device
from trezor import io
from trezor.crypto import hmac
from trezor.crypto.hashlib import sha256
from trezor.utils import consteq
if False:
from typing import Optional
SD_CARD_HOT_SWAPPABLE = False
SD_SALT_LEN_BYTES = const(32)
SD_SALT_AUTH_TAG_LEN_BYTES = const(16)
class SdSaltMismatch(Exception):
pass
def is_enabled() -> bool:
return storage.device.get_sd_salt_auth_key() is not None
def compute_auth_tag(salt: bytes, auth_key: bytes) -> bytes:
digest = hmac.new(auth_key, salt, sha256).digest()
return digest[:SD_SALT_AUTH_TAG_LEN_BYTES]
def _get_device_dir() -> str:
return "/trezor/device_{}".format(storage.device.get_device_id().lower())
def _get_salt_path(new: bool = False) -> str:
return "{}/salt{}".format(_get_device_dir(), ".new" if new else "")
def _load_salt(fs: io.FatFS, auth_key: bytes, path: str) -> Optional[bytearray]:
# Load the salt file if it exists.
try:
with fs.open(path, "r") as f:
salt = bytearray(SD_SALT_LEN_BYTES)
stored_tag = bytearray(SD_SALT_AUTH_TAG_LEN_BYTES)
f.read(salt)
f.read(stored_tag)
except OSError:
return None
# Check the salt's authentication tag.
computed_tag = compute_auth_tag(salt, auth_key)
if not consteq(computed_tag, stored_tag):
return None
return salt
def load_sd_salt() -> Optional[bytearray]:
salt_auth_key = storage.device.get_sd_salt_auth_key()
if salt_auth_key is None:
return None
sd = io.SDCard()
if not sd.power(True):
raise OSError
salt_path = _get_salt_path()
new_salt_path = _get_salt_path(new=True)
try:
fs = io.FatFS()
try:
fs.mount()
except OSError as e:
# SD card is probably not formatted. For purposes of loading SD salt, this
# is identical to having the wrong card in.
raise SdSaltMismatch from e
salt = _load_salt(fs, salt_auth_key, salt_path)
if salt is not None:
return salt
# Check if there is a new salt.
salt = _load_salt(fs, salt_auth_key, new_salt_path)
if salt is None:
# No valid salt file on this SD card.
raise SdSaltMismatch
# Normal salt file does not exist, but new salt file exists. That means that
# SD salt regeneration was interrupted earlier. Bring into consistent state.
# TODO Possibly overwrite salt file with random data.
try:
fs.unlink(salt_path)
except OSError:
pass
# fs.rename can fail with a write error, which falls through as an OSError.
# This should be handled in calling code, by allowing the user to retry.
fs.rename(new_salt_path, salt_path)
return salt
finally:
fs.unmount()
sd.power(False)
def set_sd_salt(salt: bytes, salt_tag: bytes, stage: bool = False) -> None:
salt_path = _get_salt_path(stage)
sd = io.SDCard()
if not sd.power(True):
raise OSError
try:
fs = io.FatFS()
fs.mount()
fs.mkdir("/trezor", True)
fs.mkdir(_get_device_dir(), True)
with fs.open(salt_path, "w") as f:
f.write(salt)
f.write(salt_tag)
finally:
fs.unmount()
sd.power(False)
def commit_sd_salt() -> None:
salt_path = _get_salt_path(new=False)
new_salt_path = _get_salt_path(new=True)
sd = io.SDCard()
fs = io.FatFS()
if not sd.power(True):
raise OSError
try:
fs.mount()
# TODO Possibly overwrite salt file with random data.
try:
fs.unlink(salt_path)
except OSError:
pass
fs.rename(new_salt_path, salt_path)
finally:
fs.unmount()
sd.power(False)
def remove_sd_salt() -> None:
salt_path = _get_salt_path()
sd = io.SDCard()
fs = io.FatFS()
if not sd.power(True):
raise OSError
try:
fs.mount()
# TODO Possibly overwrite salt file with random data.
fs.unlink(salt_path)
finally:
fs.unmount()
sd.power(False)

@ -0,0 +1,37 @@
from micropython import const
from storage import common
if False:
from typing import Optional
_RESIDENT_CREDENTIAL_START_KEY = const(1)
MAX_RESIDENT_CREDENTIALS = const(100)
def get_resident_credential(index: int) -> Optional[bytes]:
if not (0 <= index < MAX_RESIDENT_CREDENTIALS):
raise ValueError # invalid credential index
return common.get(common.APP_WEBAUTHN, index + _RESIDENT_CREDENTIAL_START_KEY)
def set_resident_credential(index: int, data: bytes) -> None:
if not (0 <= index < MAX_RESIDENT_CREDENTIALS):
raise ValueError # invalid credential index
common.set(common.APP_WEBAUTHN, index + _RESIDENT_CREDENTIAL_START_KEY, data)
def delete_resident_credential(index: int) -> None:
if not (0 <= index < MAX_RESIDENT_CREDENTIALS):
raise ValueError # invalid credential index
common.delete(common.APP_WEBAUTHN, index + _RESIDENT_CREDENTIAL_START_KEY)
def delete_all_resident_credentials() -> None:
for i in range(MAX_RESIDENT_CREDENTIALS):
common.delete(common.APP_WEBAUTHN, i + _RESIDENT_CREDENTIAL_START_KEY)

@ -134,6 +134,9 @@ class DummyContext:
return await loop.race(*tasks) return await loop.race(*tasks)
DUMMY_CONTEXT = DummyContext()
class Context: class Context:
def __init__(self, iface: WireInterface, sid: int) -> None: def __init__(self, iface: WireInterface, sid: int) -> None:
self.iface = iface self.iface = iface

@ -1,7 +1,6 @@
from storage.device import get_device_id
from trezor import io, utils from trezor import io, utils
from apps.common.storage.device import get_device_id
# fmt: off # fmt: off
# interface used for trezor wire protocol # interface used for trezor wire protocol

Loading…
Cancel
Save