diff --git a/core/embed/extmod/modtrezorconfig/modtrezorconfig.c b/core/embed/extmod/modtrezorconfig/modtrezorconfig.c index 3c6356883b..b8755a4093 100644 --- a/core/embed/extmod/modtrezorconfig/modtrezorconfig.c +++ b/core/embed/extmod/modtrezorconfig/modtrezorconfig.c @@ -67,41 +67,39 @@ STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) { STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1, mod_trezorconfig_init); -/// def unlock(pin: int, ext_salt: Optional[bytes] = None) -> bool: +/// def unlock(pin: int, ext_salt: Optional[bytes]) -> bool: /// """ /// Attempts to unlock the storage with the given PIN and external salt. /// Returns True on success, False on failure. /// """ -STATIC mp_obj_t mod_trezorconfig_unlock(size_t n_args, const mp_obj_t *args) { - uint32_t pin = trezor_obj_get_uint(args[0]); - const uint8_t *ext_salt = NULL; - if (n_args > 1 && args[1] != mp_const_none) { - mp_buffer_info_t ext_salt_b; - mp_get_buffer_raise(args[1], &ext_salt_b, MP_BUFFER_READ); +STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin, mp_obj_t ext_salt) { + uint32_t pin_i = trezor_obj_get_uint(pin); + mp_buffer_info_t ext_salt_b; + ext_salt_b.buf = NULL; + if (ext_salt != mp_const_none) { + mp_get_buffer_raise(ext_salt, &ext_salt_b, MP_BUFFER_READ); if (ext_salt_b.len != EXTERNAL_SALT_SIZE) mp_raise_msg(&mp_type_ValueError, "Invalid length of external salt."); - ext_salt = ext_salt_b.buf; } - if (sectrue != storage_unlock(pin, ext_salt)) { + if (sectrue != storage_unlock(pin_i, ext_salt_b.buf)) { return mp_const_false; } return mp_const_true; } -STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_unlock_obj, 1, 2, - mod_trezorconfig_unlock); +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_unlock_obj, + mod_trezorconfig_unlock); -/// def check_pin(pin: int, ext_salt: Optional[bytes] = None) -> bool: +/// def check_pin(pin: int, ext_salt: Optional[bytes]) -> bool: /// """ /// Check the given PIN with the given external salt. /// Returns True on success, False on failure. /// """ -STATIC mp_obj_t mod_trezorconfig_check_pin(size_t n_args, - const mp_obj_t *args) { - return mod_trezorconfig_unlock(n_args, args); +STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin, mp_obj_t ext_salt) { + return mod_trezorconfig_unlock(pin, ext_salt); } -STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_check_pin_obj, 1, 2, - mod_trezorconfig_check_pin); +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_check_pin_obj, + mod_trezorconfig_check_pin); /// def lock() -> None: /// """ @@ -140,8 +138,8 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_get_pin_rem_obj, /// def change_pin( /// oldpin: int, /// newpin: int, -/// old_ext_salt: Optional[bytes] = None, -/// new_ext_salt: Optional[bytes] = None, +/// old_ext_salt: Optional[bytes], +/// new_ext_salt: Optional[bytes], /// ) -> bool: /// """ /// Change PIN and external salt. Returns True on success, False on failure. @@ -152,14 +150,14 @@ STATIC mp_obj_t mod_trezorconfig_change_pin(size_t n_args, uint32_t newpin = trezor_obj_get_uint(args[1]); mp_buffer_info_t ext_salt_b; const uint8_t *old_ext_salt = NULL; - if (n_args > 2 && args[2] != mp_const_none) { + if (args[2] != mp_const_none) { mp_get_buffer_raise(args[2], &ext_salt_b, MP_BUFFER_READ); if (ext_salt_b.len != EXTERNAL_SALT_SIZE) mp_raise_msg(&mp_type_ValueError, "Invalid length of external salt."); old_ext_salt = ext_salt_b.buf; } const uint8_t *new_ext_salt = NULL; - if (n_args > 3 && args[3] != mp_const_none) { + if (args[3] != mp_const_none) { mp_get_buffer_raise(args[3], &ext_salt_b, MP_BUFFER_READ); if (ext_salt_b.len != EXTERNAL_SALT_SIZE) mp_raise_msg(&mp_type_ValueError, "Invalid length of external salt."); @@ -172,7 +170,7 @@ STATIC mp_obj_t mod_trezorconfig_change_pin(size_t n_args, } return mp_const_true; } -STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_change_pin_obj, 2, +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_change_pin_obj, 4, 4, mod_trezorconfig_change_pin); /// def get(app: int, key: int, public: bool = False) -> Optional[bytes]: diff --git a/core/mocks/generated/trezorconfig.pyi b/core/mocks/generated/trezorconfig.pyi index ae3742467e..5b8645b9ba 100644 --- a/core/mocks/generated/trezorconfig.pyi +++ b/core/mocks/generated/trezorconfig.pyi @@ -12,7 +12,7 @@ def init( # extmod/modtrezorconfig/modtrezorconfig.c -def unlock(pin: int, ext_salt: Optional[bytes] = None) -> bool: +def unlock(pin: int, ext_salt: Optional[bytes]) -> bool: """ Attempts to unlock the storage with the given PIN and external salt. Returns True on success, False on failure. @@ -20,7 +20,7 @@ def unlock(pin: int, ext_salt: Optional[bytes] = None) -> bool: # extmod/modtrezorconfig/modtrezorconfig.c -def check_pin(pin: int, ext_salt: Optional[bytes] = None) -> bool: +def check_pin(pin: int, ext_salt: Optional[bytes]) -> bool: """ Check the given PIN with the given external salt. Returns True on success, False on failure. @@ -52,8 +52,8 @@ def get_pin_rem() -> int: def change_pin( oldpin: int, newpin: int, - old_ext_salt: Optional[bytes] = None, - new_ext_salt: Optional[bytes] = None, + old_ext_salt: Optional[bytes], + new_ext_salt: Optional[bytes], ) -> bool: """ Change PIN and external salt. Returns True on success, False on failure. diff --git a/core/src/apps/common/sd_salt.py b/core/src/apps/common/sd_salt.py new file mode 100644 index 0000000000..94b81fa811 --- /dev/null +++ b/core/src/apps/common/sd_salt.py @@ -0,0 +1,193 @@ +from micropython import const + +from trezor import io, ui, wire +from trezor.crypto import hmac +from trezor.crypto.hashlib import sha256 +from trezor.ui.confirm import Confirm +from trezor.ui.text import Text +from trezor.utils import consteq + +from apps.common import storage +from apps.common.confirm import require_confirm + +if False: + from typing import Optional + + +class SdProtectCancelled(Exception): + pass + + +SD_SALT_LEN_BYTES = const(32) +SD_SALT_AUTH_TAG_LEN_BYTES = const(16) +SD_SALT_AUTH_KEY_LEN_BYTES = const(16) + + +async def wrong_card_dialog(ctx: Optional[wire.Context]) -> None: + text = Text("SD card protection", ui.ICON_WRONG) + text.bold("Wrong SD card.") + text.br_half() + text.normal("Please unplug the", "device and insert a", "different card.") + if ctx is None: + await Confirm(text, confirm=None) + else: + await require_confirm(ctx, text, confirm=None) + + +async def insert_card_dialog(ctx: Optional[wire.Context]) -> None: + text = Text("SD card protection") + text.bold("SD card required.") + text.br_half() + text.normal("Please unplug the", "device and insert your", "SD card.") + if ctx is None: + await Confirm(text, confirm=None) + else: + await require_confirm(ctx, text, confirm=None) + + +async def request_sd_salt( + ctx: Optional[wire.Context], salt_auth_key: bytes +) -> bytearray: + device_dir = "/trezor/device_%s" % storage.device.get_device_id() + salt_path = "%s/salt" % device_dir + new_salt_path = "%s/salt.new" % device_dir + + sd = io.SDCard() + fs = io.FatFS() + if not sd.power(True): + await insert_card_dialog(ctx) + raise SdProtectCancelled + + try: + fs.mount() + + # Load salt if it exists. + try: + with fs.open(salt_path, "r") as f: + salt = bytearray(SD_SALT_LEN_BYTES) # type: Optional[bytearray] + salt_tag = bytearray(SD_SALT_AUTH_TAG_LEN_BYTES) + f.read(salt) + f.read(salt_tag) + except OSError: + salt = None + + if salt is not None and consteq( + hmac.new(salt_auth_key, salt, sha256).digest()[:SD_SALT_AUTH_TAG_LEN_BYTES], + salt_tag, + ): + return salt + + # Load salt.new if it exists. + try: + with fs.open(new_salt_path, "r") as f: + new_salt = bytearray(SD_SALT_LEN_BYTES) # type: Optional[bytearray] + new_salt_tag = bytearray(SD_SALT_AUTH_TAG_LEN_BYTES) + f.read(new_salt) + f.read(new_salt_tag) + except OSError: + new_salt = None + + if new_salt is not None and consteq( + hmac.new(salt_auth_key, new_salt, sha256).digest()[ + :SD_SALT_AUTH_TAG_LEN_BYTES + ], + new_salt_tag, + ): + # SD salt regeneration was interrupted earlier. Bring into consistent state. + # TODO Possibly overwrite salt file with random data. + try: + fs.unlink(salt_path) + except OSError: + pass + fs.rename(new_salt_path, salt_path) + return new_salt + finally: + fs.unmount() + sd.power(False) + + await wrong_card_dialog(ctx) + raise SdProtectCancelled + + +async def set_sd_salt( + ctx: Optional[wire.Context], salt: bytes, salt_tag: bytes, filename: str = "salt" +) -> None: + device_dir = "/trezor/device_%s" % storage.device.get_device_id() + salt_path = "%s/%s" % (device_dir, filename) + + sd = io.SDCard() + fs = io.FatFS() + if not sd.power(True): + await insert_card_dialog(ctx) + raise SdProtectCancelled + + try: + fs.mount() + + try: + fs.mkdir("/trezor") + except OSError: + # Directory already exists. + pass + + try: + fs.mkdir(device_dir) + except OSError: + # Directory already exists. + pass + + with fs.open(salt_path, "w") as f: + f.write(salt) + f.write(salt_tag) + finally: + fs.unmount() + sd.power(False) + + +async def stage_sd_salt( + ctx: Optional[wire.Context], salt: bytes, salt_tag: bytes +) -> None: + await set_sd_salt(ctx, salt, salt_tag, "salt.new") + + +async def commit_sd_salt(ctx: Optional[wire.Context]) -> None: + device_dir = "/trezor/device_%s" % storage.device.get_device_id() + salt_path = "%s/salt" % device_dir + new_salt_path = "%s/salt.new" % device_dir + + sd = io.SDCard() + fs = io.FatFS() + if not sd.power(True): + await insert_card_dialog(ctx) + raise SdProtectCancelled + + try: + fs.mount() + # TODO Possibly overwrite salt file with random data. + try: + fs.unlink(salt_path) + except OSError: + pass + fs.rename(new_salt_path, salt_path) + finally: + fs.unmount() + sd.power(False) + + +async def remove_sd_salt(ctx: Optional[wire.Context]) -> None: + device_dir = "/trezor/device_%s" % storage.device.get_device_id() + salt_path = "%s/salt" % device_dir + + sd = io.SDCard() + fs = io.FatFS() + if not sd.power(True): + await insert_card_dialog(ctx) + raise SdProtectCancelled + + try: + fs.mount() + # TODO Possibly overwrite salt file with random data. + fs.unlink(salt_path) + finally: + fs.unmount() + sd.power(False) diff --git a/core/src/apps/common/storage/common.py b/core/src/apps/common/storage/common.py index af188f1181..f00c3ea363 100644 --- a/core/src/apps/common/storage/common.py +++ b/core/src/apps/common/storage/common.py @@ -28,8 +28,8 @@ def get(app: int, key: int, public: bool = False) -> Optional[bytes]: return config.get(app, key, public) -def delete(app: int, key: int) -> None: - config.delete(app, key) +def delete(app: int, key: int, public: bool = False) -> None: + config.delete(app, key, public) def set_true_or_delete(app: int, key: int, value: bool) -> None: diff --git a/core/src/apps/common/storage/device.py b/core/src/apps/common/storage/device.py index 693e82579d..a9104db866 100644 --- a/core/src/apps/common/storage/device.py +++ b/core/src/apps/common/storage/device.py @@ -3,6 +3,7 @@ from ubinascii import hexlify from trezor.crypto import random +from apps.common.sd_salt import SD_SALT_AUTH_KEY_LEN_BYTES from apps.common.storage import common if False: @@ -31,6 +32,7 @@ _MNEMONIC_TYPE = const(0x0E) # int _ROTATION = const(0x0F) # int _SLIP39_IDENTIFIER = const(0x10) # bool _SLIP39_ITERATION_EXPONENT = const(0x11) # int +_SD_SALT_AUTH_KEY = const(0x12) # bytes # fmt: on HOMESCREEN_MAXSIZE = 16384 @@ -234,3 +236,25 @@ def get_slip39_iteration_exponent() -> Optional[int]: The device's actual SLIP-39 iteration exponent used in passphrase derivation. """ return common.get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) + + +def get_sd_salt_auth_key() -> Optional[bytes]: + """ + The key used to check the authenticity of the SD card salt. + """ + auth_key = common.get(_NAMESPACE, _SD_SALT_AUTH_KEY, public=True) + if auth_key is not None and len(auth_key) != SD_SALT_AUTH_KEY_LEN_BYTES: + raise ValueError + return auth_key + + +def set_sd_salt_auth_key(auth_key: Optional[bytes]) -> None: + """ + The key used to check the authenticity of the SD card salt. + """ + if auth_key is not None: + if len(auth_key) != SD_SALT_AUTH_KEY_LEN_BYTES: + raise ValueError + return common.set(_NAMESPACE, _SD_SALT_AUTH_KEY, auth_key, public=True) + else: + return common.delete(_NAMESPACE, _SD_SALT_AUTH_KEY, public=True) diff --git a/core/src/apps/management/__init__.py b/core/src/apps/management/__init__.py index e1bbbc3869..51eb80ced0 100644 --- a/core/src/apps/management/__init__.py +++ b/core/src/apps/management/__init__.py @@ -14,3 +14,4 @@ def boot() -> None: wire.add(MessageType.ApplyFlags, __name__, "apply_flags") wire.add(MessageType.ChangePin, __name__, "change_pin") wire.add(MessageType.SetU2FCounter, __name__, "set_u2f_counter") + wire.add(MessageType.SdProtect, __name__, "sd_protect") diff --git a/core/src/apps/management/change_pin.py b/core/src/apps/management/change_pin.py index 793eef62af..b8dbdf61a0 100644 --- a/core/src/apps/management/change_pin.py +++ b/core/src/apps/management/change_pin.py @@ -9,26 +9,25 @@ from trezor.ui.text import Text from apps.common.confirm import require_confirm from apps.common.request_pin import PinCancelled, request_pin +from apps.common.sd_salt import request_sd_salt +from apps.common.storage import device if False: - from typing import Any + from typing import Any, Optional, Tuple from trezor.messages.ChangePin import ChangePin async def change_pin(ctx: wire.Context, msg: ChangePin) -> Success: - # confirm that user wants to change the pin await require_confirm_change_pin(ctx, msg) - # get current pin, return failure if invalid - if config.has_pin(): - curpin = await request_pin_ack(ctx, "Enter old PIN", config.get_pin_rem()) - # if removing, defer check to change_pin() - if not msg.remove: - if not config.check_pin(pin_to_int(curpin)): - raise wire.PinInvalid("PIN invalid") - else: - curpin = "" + # get old pin + curpin, salt = await request_pin_and_sd_salt(ctx, "Enter old PIN") + + # if changing pin, pre-check the entered pin before getting new pin + if curpin and not msg.remove: + if not config.check_pin(pin_to_int(curpin), salt): + raise wire.PinInvalid("PIN invalid") # get new pin if not msg.remove: @@ -37,7 +36,7 @@ async def change_pin(ctx: wire.Context, msg: ChangePin) -> Success: newpin = "" # write into storage - if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin)): + if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin), salt, salt): raise wire.PinInvalid("PIN invalid") if newpin: @@ -77,6 +76,23 @@ async def request_pin_confirm(ctx: wire.Context, *args: Any, **kwargs: Any) -> s await pin_mismatch() +async def request_pin_and_sd_salt( + ctx: wire.Context, prompt: str = "Enter your PIN", allow_cancel: bool = True +) -> Tuple[str, Optional[bytearray]]: + salt_auth_key = device.get_sd_salt_auth_key() + if salt_auth_key is not None: + salt = await request_sd_salt(ctx, salt_auth_key) # type: Optional[bytearray] + else: + salt = None + + if config.has_pin(): + pin = await request_pin_ack(ctx, prompt, config.get_pin_rem(), allow_cancel) + else: + pin = "" + + return pin, salt + + async def request_pin_ack(ctx: wire.Context, *args: Any, **kwargs: Any) -> str: try: await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck) diff --git a/core/src/apps/management/load_device.py b/core/src/apps/management/load_device.py index 7fabba9063..b7b8aeeb7f 100644 --- a/core/src/apps/management/load_device.py +++ b/core/src/apps/management/load_device.py @@ -55,6 +55,6 @@ async def load_device(ctx, msg): use_passphrase=msg.passphrase_protection, label=msg.label ) if msg.pin: - config.change_pin(pin_to_int(""), pin_to_int(msg.pin)) + config.change_pin(pin_to_int(""), pin_to_int(msg.pin), None, None) return Success(message="Device loaded") diff --git a/core/src/apps/management/recovery_device/__init__.py b/core/src/apps/management/recovery_device/__init__.py index fa13558ed6..8342c5b238 100644 --- a/core/src/apps/management/recovery_device/__init__.py +++ b/core/src/apps/management/recovery_device/__init__.py @@ -6,7 +6,7 @@ from trezor.ui.text import Text from apps.common import storage from apps.common.confirm import require_confirm -from apps.management.change_pin import request_pin_ack, request_pin_confirm +from apps.management.change_pin import request_pin_and_sd_salt, request_pin_confirm from apps.management.recovery_device.homescreen import recovery_process if False: @@ -24,13 +24,10 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success: await _continue_dialog(ctx, msg) - # for dry run pin needs to entered + # for dry run pin needs to be entered if msg.dry_run: - if config.has_pin(): - curpin = await request_pin_ack(ctx, "Enter PIN", config.get_pin_rem()) - else: - curpin = "" - if not config.check_pin(pin_to_int(curpin)): + curpin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN") + if not config.check_pin(pin_to_int(curpin), salt): raise wire.PinInvalid("PIN invalid") # set up pin if requested @@ -38,7 +35,7 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success: if msg.dry_run: raise wire.ProcessError("Can't setup PIN during dry_run recovery.") newpin = await request_pin_confirm(ctx, allow_cancel=False) - config.change_pin(pin_to_int(""), pin_to_int(newpin)) + config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None) if msg.u2f_counter: storage.device.set_u2f_counter(msg.u2f_counter) diff --git a/core/src/apps/management/reset_device.py b/core/src/apps/management/reset_device.py index 7770b7234c..1c96c09495 100644 --- a/core/src/apps/management/reset_device.py +++ b/core/src/apps/management/reset_device.py @@ -71,7 +71,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success: await backup_bip39_wallet(ctx, secret) # write PIN into storage - if not config.change_pin(pin_to_int(""), pin_to_int(newpin)): + if not config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None): raise wire.ProcessError("Could not change PIN") # write settings and master secret into storage diff --git a/core/src/apps/management/sd_protect.py b/core/src/apps/management/sd_protect.py new file mode 100644 index 0000000000..fe9c2da36a --- /dev/null +++ b/core/src/apps/management/sd_protect.py @@ -0,0 +1,169 @@ +from trezor import config, ui, wire +from trezor.crypto import hmac, random +from trezor.crypto.hashlib import sha256 +from trezor.messages import SdProtectOperationType +from trezor.messages.Success import Success +from trezor.pin import pin_to_int +from trezor.ui.text import Text + +from apps.common.confirm import require_confirm +from apps.common.sd_salt import ( + SD_SALT_AUTH_KEY_LEN_BYTES, + SD_SALT_AUTH_TAG_LEN_BYTES, + SD_SALT_LEN_BYTES, + commit_sd_salt, + remove_sd_salt, + set_sd_salt, + stage_sd_salt, +) +from apps.common.storage import device, is_initialized +from apps.management.change_pin import request_pin_ack, request_pin_and_sd_salt + +if False: + from trezor.messages.SdProtect import SdProtect + + +async def sd_protect(ctx: wire.Context, msg: SdProtect) -> Success: + if not is_initialized(): + raise wire.ProcessError("Device is not initialized") + + if msg.operation == SdProtectOperationType.ENABLE: + return await sd_protect_enable(ctx, msg) + elif msg.operation == SdProtectOperationType.DISABLE: + return await sd_protect_disable(ctx, msg) + elif msg.operation == SdProtectOperationType.REFRESH: + return await sd_protect_refresh(ctx, msg) + else: + raise wire.ProcessError("Unknown operation") + + +async def sd_protect_enable(ctx: wire.Context, msg: SdProtect) -> Success: + salt_auth_key = device.get_sd_salt_auth_key() + if salt_auth_key is not None: + raise wire.ProcessError("SD card protection already enabled") + + # Confirm that user wants to proceed with the operation. + await require_confirm_sd_protect(ctx, msg) + + # Get the current PIN. + if config.has_pin(): + pin = pin_to_int(await request_pin_ack(ctx, "Enter PIN", config.get_pin_rem())) + else: + pin = pin_to_int("") + + # Check PIN and prepare salt file. + salt = random.bytes(SD_SALT_LEN_BYTES) + salt_auth_key = random.bytes(SD_SALT_AUTH_KEY_LEN_BYTES) + salt_tag = hmac.new(salt_auth_key, salt, sha256).digest()[ + :SD_SALT_AUTH_TAG_LEN_BYTES + ] + try: + await set_sd_salt(ctx, salt, salt_tag) + except Exception: + raise wire.ProcessError("Failed to write to SD card") + + if not config.change_pin(pin, pin, None, salt): + # Wrong PIN. Clean up the prepared salt file. + try: + await remove_sd_salt(ctx) + except Exception: + # The cleanup is not necessary for the correct functioning of + # SD-protection. If it fails for any reason, we suppress the + # exception, because primarily we need to raise wire.PinInvalid. + pass + raise wire.PinInvalid("PIN invalid") + + device.set_sd_salt_auth_key(salt_auth_key) + + return Success(message="SD card protection enabled") + + +async def sd_protect_disable(ctx: wire.Context, msg: SdProtect) -> Success: + if device.get_sd_salt_auth_key() is None: + raise wire.ProcessError("SD card protection not enabled") + + # Confirm that user wants to proceed with the operation. + await require_confirm_sd_protect(ctx, msg) + + # Get the current PIN and salt from the SD card. + pin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN") + + # Check PIN and remove salt. + if not config.change_pin(pin_to_int(pin), pin_to_int(pin), salt, None): + raise wire.PinInvalid("PIN invalid") + + device.set_sd_salt_auth_key(None) + + try: + # Clean up. + await remove_sd_salt(ctx) + except Exception: + # The cleanup is not necessary for the correct functioning of + # SD-protection. If it fails for any reason, we suppress the exception, + # because overall SD-protection was successfully disabled. + pass + + return Success(message="SD card protection disabled") + + +async def sd_protect_refresh(ctx: wire.Context, msg: SdProtect) -> Success: + if device.get_sd_salt_auth_key() is None: + raise wire.ProcessError("SD card protection not enabled") + + # Confirm that user wants to proceed with the operation. + await require_confirm_sd_protect(ctx, msg) + + # Get the current PIN and salt from the SD card. + pin, old_salt = await request_pin_and_sd_salt(ctx, "Enter PIN") + + # Check PIN and change salt. + new_salt = random.bytes(SD_SALT_LEN_BYTES) + new_salt_auth_key = random.bytes(SD_SALT_AUTH_KEY_LEN_BYTES) + new_salt_tag = hmac.new(new_salt_auth_key, new_salt, sha256).digest()[ + :SD_SALT_AUTH_TAG_LEN_BYTES + ] + try: + await stage_sd_salt(ctx, new_salt, new_salt_tag) + except Exception: + raise wire.ProcessError("Failed to write to SD card") + + if not config.change_pin(pin_to_int(pin), pin_to_int(pin), old_salt, new_salt): + raise wire.PinInvalid("PIN invalid") + + device.set_sd_salt_auth_key(new_salt_auth_key) + + try: + # Clean up. + await commit_sd_salt(ctx) + except Exception: + # If the cleanup fails, then request_sd_salt() will bring the SD card + # into a consistent state. We suppress the exception, because overall + # SD-protection was successfully refreshed. + pass + + return Success(message="SD card protection refreshed") + + +def require_confirm_sd_protect(ctx: wire.Context, msg: SdProtect) -> None: + if msg.operation == SdProtectOperationType.ENABLE: + text = Text("SD card protection", ui.ICON_CONFIG) + text.normal( + "Do you really want to", "secure your device with", "SD card protection?" + ) + elif msg.operation == SdProtectOperationType.DISABLE: + text = Text("SD card protection", ui.ICON_CONFIG) + text.normal( + "Do you really want to", "remove SD card", "protection from your", "device?" + ) + elif msg.operation == SdProtectOperationType.REFRESH: + text = Text("SD card protection", ui.ICON_CONFIG) + text.normal( + "Do you really want to", + "replace the current", + "SD card secret with a", + "newly generated one?", + ) + else: + raise wire.ProcessError("Unknown operation") + + return require_confirm(ctx, text) diff --git a/core/src/boot.py b/core/src/boot.py index f6c414482f..cbcf310659 100644 --- a/core/src/boot.py +++ b/core/src/boot.py @@ -3,21 +3,38 @@ from trezor.pin import pin_to_int, show_pin_timeout from apps.common import storage from apps.common.request_pin import request_pin +from apps.common.sd_salt import request_sd_salt +from apps.common.storage import device + +if False: + from typing import Optional async def bootscreen() -> None: ui.display.orientation(storage.device.get_rotation()) + salt_auth_key = device.get_sd_salt_auth_key() + while True: try: + if salt_auth_key is not None or config.has_pin(): + await lockscreen() + + if salt_auth_key is not None: + salt = await request_sd_salt( + None, salt_auth_key + ) # type: Optional[bytearray] + else: + salt = None + if not config.has_pin(): - config.unlock(pin_to_int("")) + config.unlock(pin_to_int(""), salt) storage.init_unlocked() return - await lockscreen() + label = "Enter your PIN" while True: pin = await request_pin(label, config.get_pin_rem()) - if config.unlock(pin_to_int(pin)): + if config.unlock(pin_to_int(pin), salt): storage.init_unlocked() return else: diff --git a/core/tests/test_trezor.config.py b/core/tests/test_trezor.config.py index 416ec30594..a6879996d2 100644 --- a/core/tests/test_trezor.config.py +++ b/core/tests/test_trezor.config.py @@ -27,7 +27,7 @@ class TestConfig(unittest.TestCase): def test_wipe(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int('')), True) + self.assertEqual(config.unlock(pin_to_int(''), None), True) config.set(1, 1, b'hello') config.set(1, 2, b'world') v0 = config.get(1, 1) @@ -44,7 +44,7 @@ class TestConfig(unittest.TestCase): for _ in range(128): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int('')), True) + self.assertEqual(config.unlock(pin_to_int(''), None), True) appid, key = random_entry() value = random.bytes(16) config.set(appid, key, value) @@ -58,7 +58,7 @@ class TestConfig(unittest.TestCase): def test_public(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int('')), True) + self.assertEqual(config.unlock(pin_to_int(''), None), True) appid, key = random_entry() @@ -84,25 +84,59 @@ class TestConfig(unittest.TestCase): def test_change_pin(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int('')), True) + self.assertEqual(config.unlock(pin_to_int(''), None), True) with self.assertRaises(RuntimeError): config.set(PINAPP, PINKEY, b'value') - self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666')), False) - self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), True) + self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666'), None, None), False) + self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None, None), True) self.assertEqual(config.get(PINAPP, PINKEY), None) config.set(1, 1, b'value') config.init() - self.assertEqual(config.unlock(pin_to_int('000')), True) - config.change_pin(pin_to_int('000'), pin_to_int('')) + self.assertEqual(config.unlock(pin_to_int('000'), None), True) + config.change_pin(pin_to_int('000'), pin_to_int(''), None, None) config.init() - self.assertEqual(config.unlock(pin_to_int('000')), False) - self.assertEqual(config.unlock(pin_to_int('')), True) + self.assertEqual(config.unlock(pin_to_int('000'), None), False) + self.assertEqual(config.unlock(pin_to_int(''), None), True) + self.assertEqual(config.get(1, 1), b'value') + + def test_change_sd_salt(self): + salt1 = b"0123456789abcdef0123456789abcdef" + salt2 = b"0123456789ABCDEF0123456789ABCDEF" + + # Enable PIN and SD salt. + config.init() + config.wipe() + self.assertTrue(config.unlock(pin_to_int(''), None)) + config.set(1, 1, b'value') + self.assertFalse(config.change_pin(pin_to_int(''), pin_to_int(''), salt1, None)) + self.assertTrue(config.change_pin(pin_to_int(''), pin_to_int('000'), None, salt1)) + self.assertEqual(config.get(1, 1), b'value') + + # Disable PIN and change SD salt. + config.init() + self.assertFalse(config.unlock(pin_to_int('000'), None)) + self.assertIsNone(config.get(1, 1)) + self.assertTrue(config.unlock(pin_to_int('000'), salt1)) + self.assertTrue(config.change_pin(pin_to_int('000'), pin_to_int(''), salt1, salt2)) + self.assertEqual(config.get(1, 1), b'value') + + # Disable SD salt. + config.init() + self.assertFalse(config.unlock(pin_to_int('000'), salt2)) + self.assertIsNone(config.get(1, 1)) + self.assertTrue(config.unlock(pin_to_int(''), salt2)) + self.assertTrue(config.change_pin(pin_to_int(''), pin_to_int(''), salt2, None)) + self.assertEqual(config.get(1, 1), b'value') + + # Check that PIN and SD salt are disabled. + config.init() + self.assertTrue(config.unlock(pin_to_int(''), None)) self.assertEqual(config.get(1, 1), b'value') def test_set_get(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int('')), True) + self.assertEqual(config.unlock(pin_to_int(''), None), True) for _ in range(32): appid, key = random_entry() value = random.bytes(128) @@ -113,7 +147,7 @@ class TestConfig(unittest.TestCase): def test_compact(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int('')), True) + self.assertEqual(config.unlock(pin_to_int(''), None), True) appid, key = 1, 1 for _ in range(259): value = random.bytes(259) @@ -124,7 +158,7 @@ class TestConfig(unittest.TestCase): def test_get_default(self): config.init() config.wipe() - self.assertEqual(config.unlock(pin_to_int('')), True) + self.assertEqual(config.unlock(pin_to_int(''), None), True) for _ in range(128): appid, key = random_entry() value = config.get(appid, key)