1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-26 07:11:25 +00:00

core: Implement SD card protection.

This commit is contained in:
Andrew Kozlik 2019-08-13 17:50:06 +02:00
parent f867b43251
commit 6350b1c61c
13 changed files with 515 additions and 66 deletions

View File

@ -67,40 +67,38 @@ STATIC mp_obj_t mod_trezorconfig_init(size_t n_args, const mp_obj_t *args) {
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_init_obj, 0, 1,
mod_trezorconfig_init); mod_trezorconfig_init);
/// def unlock(pin: int, ext_salt: Optional[bytes] = None) -> bool: /// def unlock(pin: int, ext_salt: Optional[bytes]) -> bool:
/// """ /// """
/// Attempts to unlock the storage with the given PIN and external salt. /// Attempts to unlock the storage with the given PIN and external salt.
/// Returns True on success, False on failure. /// Returns True on success, False on failure.
/// """ /// """
STATIC mp_obj_t mod_trezorconfig_unlock(size_t n_args, const mp_obj_t *args) { STATIC mp_obj_t mod_trezorconfig_unlock(mp_obj_t pin, mp_obj_t ext_salt) {
uint32_t pin = trezor_obj_get_uint(args[0]); uint32_t pin_i = trezor_obj_get_uint(pin);
const uint8_t *ext_salt = NULL;
if (n_args > 1 && args[1] != mp_const_none) {
mp_buffer_info_t ext_salt_b; mp_buffer_info_t ext_salt_b;
mp_get_buffer_raise(args[1], &ext_salt_b, MP_BUFFER_READ); ext_salt_b.buf = NULL;
if (ext_salt != mp_const_none) {
mp_get_buffer_raise(ext_salt, &ext_salt_b, MP_BUFFER_READ);
if (ext_salt_b.len != EXTERNAL_SALT_SIZE) if (ext_salt_b.len != EXTERNAL_SALT_SIZE)
mp_raise_msg(&mp_type_ValueError, "Invalid length of external salt."); mp_raise_msg(&mp_type_ValueError, "Invalid length of external salt.");
ext_salt = ext_salt_b.buf;
} }
if (sectrue != storage_unlock(pin, ext_salt)) { if (sectrue != storage_unlock(pin_i, ext_salt_b.buf)) {
return mp_const_false; return mp_const_false;
} }
return mp_const_true; return mp_const_true;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_unlock_obj, 1, 2, STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_unlock_obj,
mod_trezorconfig_unlock); mod_trezorconfig_unlock);
/// def check_pin(pin: int, ext_salt: Optional[bytes] = None) -> bool: /// def check_pin(pin: int, ext_salt: Optional[bytes]) -> bool:
/// """ /// """
/// Check the given PIN with the given external salt. /// Check the given PIN with the given external salt.
/// Returns True on success, False on failure. /// Returns True on success, False on failure.
/// """ /// """
STATIC mp_obj_t mod_trezorconfig_check_pin(size_t n_args, STATIC mp_obj_t mod_trezorconfig_check_pin(mp_obj_t pin, mp_obj_t ext_salt) {
const mp_obj_t *args) { return mod_trezorconfig_unlock(pin, ext_salt);
return mod_trezorconfig_unlock(n_args, args);
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_check_pin_obj, 1, 2, STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorconfig_check_pin_obj,
mod_trezorconfig_check_pin); mod_trezorconfig_check_pin);
/// def lock() -> None: /// def lock() -> None:
@ -140,8 +138,8 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_0(mod_trezorconfig_get_pin_rem_obj,
/// def change_pin( /// def change_pin(
/// oldpin: int, /// oldpin: int,
/// newpin: int, /// newpin: int,
/// old_ext_salt: Optional[bytes] = None, /// old_ext_salt: Optional[bytes],
/// new_ext_salt: Optional[bytes] = None, /// new_ext_salt: Optional[bytes],
/// ) -> bool: /// ) -> bool:
/// """ /// """
/// Change PIN and external salt. Returns True on success, False on failure. /// Change PIN and external salt. Returns True on success, False on failure.
@ -152,14 +150,14 @@ STATIC mp_obj_t mod_trezorconfig_change_pin(size_t n_args,
uint32_t newpin = trezor_obj_get_uint(args[1]); uint32_t newpin = trezor_obj_get_uint(args[1]);
mp_buffer_info_t ext_salt_b; mp_buffer_info_t ext_salt_b;
const uint8_t *old_ext_salt = NULL; const uint8_t *old_ext_salt = NULL;
if (n_args > 2 && args[2] != mp_const_none) { if (args[2] != mp_const_none) {
mp_get_buffer_raise(args[2], &ext_salt_b, MP_BUFFER_READ); mp_get_buffer_raise(args[2], &ext_salt_b, MP_BUFFER_READ);
if (ext_salt_b.len != EXTERNAL_SALT_SIZE) if (ext_salt_b.len != EXTERNAL_SALT_SIZE)
mp_raise_msg(&mp_type_ValueError, "Invalid length of external salt."); mp_raise_msg(&mp_type_ValueError, "Invalid length of external salt.");
old_ext_salt = ext_salt_b.buf; old_ext_salt = ext_salt_b.buf;
} }
const uint8_t *new_ext_salt = NULL; const uint8_t *new_ext_salt = NULL;
if (n_args > 3 && args[3] != mp_const_none) { if (args[3] != mp_const_none) {
mp_get_buffer_raise(args[3], &ext_salt_b, MP_BUFFER_READ); mp_get_buffer_raise(args[3], &ext_salt_b, MP_BUFFER_READ);
if (ext_salt_b.len != EXTERNAL_SALT_SIZE) if (ext_salt_b.len != EXTERNAL_SALT_SIZE)
mp_raise_msg(&mp_type_ValueError, "Invalid length of external salt."); mp_raise_msg(&mp_type_ValueError, "Invalid length of external salt.");
@ -172,7 +170,7 @@ STATIC mp_obj_t mod_trezorconfig_change_pin(size_t n_args,
} }
return mp_const_true; return mp_const_true;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_change_pin_obj, 2, STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_trezorconfig_change_pin_obj, 4,
4, mod_trezorconfig_change_pin); 4, mod_trezorconfig_change_pin);
/// def get(app: int, key: int, public: bool = False) -> Optional[bytes]: /// def get(app: int, key: int, public: bool = False) -> Optional[bytes]:

View File

@ -12,7 +12,7 @@ def init(
# extmod/modtrezorconfig/modtrezorconfig.c # extmod/modtrezorconfig/modtrezorconfig.c
def unlock(pin: int, ext_salt: Optional[bytes] = None) -> bool: def unlock(pin: int, ext_salt: Optional[bytes]) -> bool:
""" """
Attempts to unlock the storage with the given PIN and external salt. Attempts to unlock the storage with the given PIN and external salt.
Returns True on success, False on failure. Returns True on success, False on failure.
@ -20,7 +20,7 @@ def unlock(pin: int, ext_salt: Optional[bytes] = None) -> bool:
# extmod/modtrezorconfig/modtrezorconfig.c # extmod/modtrezorconfig/modtrezorconfig.c
def check_pin(pin: int, ext_salt: Optional[bytes] = None) -> bool: def check_pin(pin: int, ext_salt: Optional[bytes]) -> bool:
""" """
Check the given PIN with the given external salt. Check the given PIN with the given external salt.
Returns True on success, False on failure. Returns True on success, False on failure.
@ -52,8 +52,8 @@ def get_pin_rem() -> int:
def change_pin( def change_pin(
oldpin: int, oldpin: int,
newpin: int, newpin: int,
old_ext_salt: Optional[bytes] = None, old_ext_salt: Optional[bytes],
new_ext_salt: Optional[bytes] = None, new_ext_salt: Optional[bytes],
) -> bool: ) -> bool:
""" """
Change PIN and external salt. Returns True on success, False on failure. Change PIN and external salt. Returns True on success, False on failure.

View File

@ -0,0 +1,193 @@
from micropython import const
from trezor import io, ui, wire
from trezor.crypto import hmac
from trezor.crypto.hashlib import sha256
from trezor.ui.confirm import Confirm
from trezor.ui.text import Text
from trezor.utils import consteq
from apps.common import storage
from apps.common.confirm import require_confirm
if False:
from typing import Optional
class SdProtectCancelled(Exception):
pass
SD_SALT_LEN_BYTES = const(32)
SD_SALT_AUTH_TAG_LEN_BYTES = const(16)
SD_SALT_AUTH_KEY_LEN_BYTES = const(16)
async def wrong_card_dialog(ctx: Optional[wire.Context]) -> None:
text = Text("SD card protection", ui.ICON_WRONG)
text.bold("Wrong SD card.")
text.br_half()
text.normal("Please unplug the", "device and insert a", "different card.")
if ctx is None:
await Confirm(text, confirm=None)
else:
await require_confirm(ctx, text, confirm=None)
async def insert_card_dialog(ctx: Optional[wire.Context]) -> None:
text = Text("SD card protection")
text.bold("SD card required.")
text.br_half()
text.normal("Please unplug the", "device and insert your", "SD card.")
if ctx is None:
await Confirm(text, confirm=None)
else:
await require_confirm(ctx, text, confirm=None)
async def request_sd_salt(
ctx: Optional[wire.Context], salt_auth_key: bytes
) -> bytearray:
device_dir = "/trezor/device_%s" % storage.device.get_device_id()
salt_path = "%s/salt" % device_dir
new_salt_path = "%s/salt.new" % device_dir
sd = io.SDCard()
fs = io.FatFS()
if not sd.power(True):
await insert_card_dialog(ctx)
raise SdProtectCancelled
try:
fs.mount()
# Load salt if it exists.
try:
with fs.open(salt_path, "r") as f:
salt = bytearray(SD_SALT_LEN_BYTES) # type: Optional[bytearray]
salt_tag = bytearray(SD_SALT_AUTH_TAG_LEN_BYTES)
f.read(salt)
f.read(salt_tag)
except OSError:
salt = None
if salt is not None and consteq(
hmac.new(salt_auth_key, salt, sha256).digest()[:SD_SALT_AUTH_TAG_LEN_BYTES],
salt_tag,
):
return salt
# Load salt.new if it exists.
try:
with fs.open(new_salt_path, "r") as f:
new_salt = bytearray(SD_SALT_LEN_BYTES) # type: Optional[bytearray]
new_salt_tag = bytearray(SD_SALT_AUTH_TAG_LEN_BYTES)
f.read(new_salt)
f.read(new_salt_tag)
except OSError:
new_salt = None
if new_salt is not None and consteq(
hmac.new(salt_auth_key, new_salt, sha256).digest()[
:SD_SALT_AUTH_TAG_LEN_BYTES
],
new_salt_tag,
):
# SD salt regeneration was interrupted earlier. Bring into consistent state.
# TODO Possibly overwrite salt file with random data.
try:
fs.unlink(salt_path)
except OSError:
pass
fs.rename(new_salt_path, salt_path)
return new_salt
finally:
fs.unmount()
sd.power(False)
await wrong_card_dialog(ctx)
raise SdProtectCancelled
async def set_sd_salt(
ctx: Optional[wire.Context], salt: bytes, salt_tag: bytes, filename: str = "salt"
) -> None:
device_dir = "/trezor/device_%s" % storage.device.get_device_id()
salt_path = "%s/%s" % (device_dir, filename)
sd = io.SDCard()
fs = io.FatFS()
if not sd.power(True):
await insert_card_dialog(ctx)
raise SdProtectCancelled
try:
fs.mount()
try:
fs.mkdir("/trezor")
except OSError:
# Directory already exists.
pass
try:
fs.mkdir(device_dir)
except OSError:
# Directory already exists.
pass
with fs.open(salt_path, "w") as f:
f.write(salt)
f.write(salt_tag)
finally:
fs.unmount()
sd.power(False)
async def stage_sd_salt(
ctx: Optional[wire.Context], salt: bytes, salt_tag: bytes
) -> None:
await set_sd_salt(ctx, salt, salt_tag, "salt.new")
async def commit_sd_salt(ctx: Optional[wire.Context]) -> None:
device_dir = "/trezor/device_%s" % storage.device.get_device_id()
salt_path = "%s/salt" % device_dir
new_salt_path = "%s/salt.new" % device_dir
sd = io.SDCard()
fs = io.FatFS()
if not sd.power(True):
await insert_card_dialog(ctx)
raise SdProtectCancelled
try:
fs.mount()
# TODO Possibly overwrite salt file with random data.
try:
fs.unlink(salt_path)
except OSError:
pass
fs.rename(new_salt_path, salt_path)
finally:
fs.unmount()
sd.power(False)
async def remove_sd_salt(ctx: Optional[wire.Context]) -> None:
device_dir = "/trezor/device_%s" % storage.device.get_device_id()
salt_path = "%s/salt" % device_dir
sd = io.SDCard()
fs = io.FatFS()
if not sd.power(True):
await insert_card_dialog(ctx)
raise SdProtectCancelled
try:
fs.mount()
# TODO Possibly overwrite salt file with random data.
fs.unlink(salt_path)
finally:
fs.unmount()
sd.power(False)

View File

@ -28,8 +28,8 @@ def get(app: int, key: int, public: bool = False) -> Optional[bytes]:
return config.get(app, key, public) return config.get(app, key, public)
def delete(app: int, key: int) -> None: def delete(app: int, key: int, public: bool = False) -> None:
config.delete(app, key) config.delete(app, key, public)
def set_true_or_delete(app: int, key: int, value: bool) -> None: def set_true_or_delete(app: int, key: int, value: bool) -> None:

View File

@ -3,6 +3,7 @@ from ubinascii import hexlify
from trezor.crypto import random from trezor.crypto import random
from apps.common.sd_salt import SD_SALT_AUTH_KEY_LEN_BYTES
from apps.common.storage import common from apps.common.storage import common
if False: if False:
@ -31,6 +32,7 @@ _MNEMONIC_TYPE = const(0x0E) # int
_ROTATION = const(0x0F) # int _ROTATION = const(0x0F) # int
_SLIP39_IDENTIFIER = const(0x10) # bool _SLIP39_IDENTIFIER = const(0x10) # bool
_SLIP39_ITERATION_EXPONENT = const(0x11) # int _SLIP39_ITERATION_EXPONENT = const(0x11) # int
_SD_SALT_AUTH_KEY = const(0x12) # bytes
# fmt: on # fmt: on
HOMESCREEN_MAXSIZE = 16384 HOMESCREEN_MAXSIZE = 16384
@ -234,3 +236,25 @@ def get_slip39_iteration_exponent() -> Optional[int]:
The device's actual SLIP-39 iteration exponent used in passphrase derivation. The device's actual SLIP-39 iteration exponent used in passphrase derivation.
""" """
return common.get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT) return common.get_uint8(_NAMESPACE, _SLIP39_ITERATION_EXPONENT)
def get_sd_salt_auth_key() -> Optional[bytes]:
"""
The key used to check the authenticity of the SD card salt.
"""
auth_key = common.get(_NAMESPACE, _SD_SALT_AUTH_KEY, public=True)
if auth_key is not None and len(auth_key) != SD_SALT_AUTH_KEY_LEN_BYTES:
raise ValueError
return auth_key
def set_sd_salt_auth_key(auth_key: Optional[bytes]) -> None:
"""
The key used to check the authenticity of the SD card salt.
"""
if auth_key is not None:
if len(auth_key) != SD_SALT_AUTH_KEY_LEN_BYTES:
raise ValueError
return common.set(_NAMESPACE, _SD_SALT_AUTH_KEY, auth_key, public=True)
else:
return common.delete(_NAMESPACE, _SD_SALT_AUTH_KEY, public=True)

View File

@ -14,3 +14,4 @@ def boot() -> None:
wire.add(MessageType.ApplyFlags, __name__, "apply_flags") wire.add(MessageType.ApplyFlags, __name__, "apply_flags")
wire.add(MessageType.ChangePin, __name__, "change_pin") wire.add(MessageType.ChangePin, __name__, "change_pin")
wire.add(MessageType.SetU2FCounter, __name__, "set_u2f_counter") wire.add(MessageType.SetU2FCounter, __name__, "set_u2f_counter")
wire.add(MessageType.SdProtect, __name__, "sd_protect")

View File

@ -9,26 +9,25 @@ from trezor.ui.text import Text
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.request_pin import PinCancelled, request_pin from apps.common.request_pin import PinCancelled, request_pin
from apps.common.sd_salt import request_sd_salt
from apps.common.storage import device
if False: if False:
from typing import Any from typing import Any, Optional, Tuple
from trezor.messages.ChangePin import ChangePin from trezor.messages.ChangePin import ChangePin
async def change_pin(ctx: wire.Context, msg: ChangePin) -> Success: async def change_pin(ctx: wire.Context, msg: ChangePin) -> Success:
# confirm that user wants to change the pin # confirm that user wants to change the pin
await require_confirm_change_pin(ctx, msg) await require_confirm_change_pin(ctx, msg)
# get current pin, return failure if invalid # get old pin
if config.has_pin(): curpin, salt = await request_pin_and_sd_salt(ctx, "Enter old PIN")
curpin = await request_pin_ack(ctx, "Enter old PIN", config.get_pin_rem())
# if removing, defer check to change_pin() # if changing pin, pre-check the entered pin before getting new pin
if not msg.remove: if curpin and not msg.remove:
if not config.check_pin(pin_to_int(curpin)): if not config.check_pin(pin_to_int(curpin), salt):
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
else:
curpin = ""
# get new pin # get new pin
if not msg.remove: if not msg.remove:
@ -37,7 +36,7 @@ async def change_pin(ctx: wire.Context, msg: ChangePin) -> Success:
newpin = "" newpin = ""
# write into storage # write into storage
if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin)): if not config.change_pin(pin_to_int(curpin), pin_to_int(newpin), salt, salt):
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
if newpin: if newpin:
@ -77,6 +76,23 @@ async def request_pin_confirm(ctx: wire.Context, *args: Any, **kwargs: Any) -> s
await pin_mismatch() await pin_mismatch()
async def request_pin_and_sd_salt(
ctx: wire.Context, prompt: str = "Enter your PIN", allow_cancel: bool = True
) -> Tuple[str, Optional[bytearray]]:
salt_auth_key = device.get_sd_salt_auth_key()
if salt_auth_key is not None:
salt = await request_sd_salt(ctx, salt_auth_key) # type: Optional[bytearray]
else:
salt = None
if config.has_pin():
pin = await request_pin_ack(ctx, prompt, config.get_pin_rem(), allow_cancel)
else:
pin = ""
return pin, salt
async def request_pin_ack(ctx: wire.Context, *args: Any, **kwargs: Any) -> str: async def request_pin_ack(ctx: wire.Context, *args: Any, **kwargs: Any) -> str:
try: try:
await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck) await ctx.call(ButtonRequest(code=ButtonRequestType.Other), ButtonAck)

View File

@ -55,6 +55,6 @@ async def load_device(ctx, msg):
use_passphrase=msg.passphrase_protection, label=msg.label use_passphrase=msg.passphrase_protection, label=msg.label
) )
if msg.pin: if msg.pin:
config.change_pin(pin_to_int(""), pin_to_int(msg.pin)) config.change_pin(pin_to_int(""), pin_to_int(msg.pin), None, None)
return Success(message="Device loaded") return Success(message="Device loaded")

View File

@ -6,7 +6,7 @@ from trezor.ui.text import Text
from apps.common import storage from apps.common import storage
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.management.change_pin import request_pin_ack, request_pin_confirm from apps.management.change_pin import request_pin_and_sd_salt, request_pin_confirm
from apps.management.recovery_device.homescreen import recovery_process from apps.management.recovery_device.homescreen import recovery_process
if False: if False:
@ -24,13 +24,10 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
await _continue_dialog(ctx, msg) await _continue_dialog(ctx, msg)
# for dry run pin needs to entered # for dry run pin needs to be entered
if msg.dry_run: if msg.dry_run:
if config.has_pin(): curpin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN")
curpin = await request_pin_ack(ctx, "Enter PIN", config.get_pin_rem()) if not config.check_pin(pin_to_int(curpin), salt):
else:
curpin = ""
if not config.check_pin(pin_to_int(curpin)):
raise wire.PinInvalid("PIN invalid") raise wire.PinInvalid("PIN invalid")
# set up pin if requested # set up pin if requested
@ -38,7 +35,7 @@ async def recovery_device(ctx: wire.Context, msg: RecoveryDevice) -> Success:
if msg.dry_run: if msg.dry_run:
raise wire.ProcessError("Can't setup PIN during dry_run recovery.") raise wire.ProcessError("Can't setup PIN during dry_run recovery.")
newpin = await request_pin_confirm(ctx, allow_cancel=False) newpin = await request_pin_confirm(ctx, allow_cancel=False)
config.change_pin(pin_to_int(""), pin_to_int(newpin)) config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None)
if msg.u2f_counter: if msg.u2f_counter:
storage.device.set_u2f_counter(msg.u2f_counter) storage.device.set_u2f_counter(msg.u2f_counter)

View File

@ -71,7 +71,7 @@ async def reset_device(ctx: wire.Context, msg: ResetDevice) -> Success:
await backup_bip39_wallet(ctx, secret) await backup_bip39_wallet(ctx, secret)
# write PIN into storage # write PIN into storage
if not config.change_pin(pin_to_int(""), pin_to_int(newpin)): if not config.change_pin(pin_to_int(""), pin_to_int(newpin), None, None):
raise wire.ProcessError("Could not change PIN") raise wire.ProcessError("Could not change PIN")
# write settings and master secret into storage # write settings and master secret into storage

View File

@ -0,0 +1,169 @@
from trezor import config, ui, wire
from trezor.crypto import hmac, random
from trezor.crypto.hashlib import sha256
from trezor.messages import SdProtectOperationType
from trezor.messages.Success import Success
from trezor.pin import pin_to_int
from trezor.ui.text import Text
from apps.common.confirm import require_confirm
from apps.common.sd_salt import (
SD_SALT_AUTH_KEY_LEN_BYTES,
SD_SALT_AUTH_TAG_LEN_BYTES,
SD_SALT_LEN_BYTES,
commit_sd_salt,
remove_sd_salt,
set_sd_salt,
stage_sd_salt,
)
from apps.common.storage import device, is_initialized
from apps.management.change_pin import request_pin_ack, request_pin_and_sd_salt
if False:
from trezor.messages.SdProtect import SdProtect
async def sd_protect(ctx: wire.Context, msg: SdProtect) -> Success:
if not is_initialized():
raise wire.ProcessError("Device is not initialized")
if msg.operation == SdProtectOperationType.ENABLE:
return await sd_protect_enable(ctx, msg)
elif msg.operation == SdProtectOperationType.DISABLE:
return await sd_protect_disable(ctx, msg)
elif msg.operation == SdProtectOperationType.REFRESH:
return await sd_protect_refresh(ctx, msg)
else:
raise wire.ProcessError("Unknown operation")
async def sd_protect_enable(ctx: wire.Context, msg: SdProtect) -> Success:
salt_auth_key = device.get_sd_salt_auth_key()
if salt_auth_key is not None:
raise wire.ProcessError("SD card protection already enabled")
# Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg)
# Get the current PIN.
if config.has_pin():
pin = pin_to_int(await request_pin_ack(ctx, "Enter PIN", config.get_pin_rem()))
else:
pin = pin_to_int("")
# Check PIN and prepare salt file.
salt = random.bytes(SD_SALT_LEN_BYTES)
salt_auth_key = random.bytes(SD_SALT_AUTH_KEY_LEN_BYTES)
salt_tag = hmac.new(salt_auth_key, salt, sha256).digest()[
:SD_SALT_AUTH_TAG_LEN_BYTES
]
try:
await set_sd_salt(ctx, salt, salt_tag)
except Exception:
raise wire.ProcessError("Failed to write to SD card")
if not config.change_pin(pin, pin, None, salt):
# Wrong PIN. Clean up the prepared salt file.
try:
await remove_sd_salt(ctx)
except Exception:
# The cleanup is not necessary for the correct functioning of
# SD-protection. If it fails for any reason, we suppress the
# exception, because primarily we need to raise wire.PinInvalid.
pass
raise wire.PinInvalid("PIN invalid")
device.set_sd_salt_auth_key(salt_auth_key)
return Success(message="SD card protection enabled")
async def sd_protect_disable(ctx: wire.Context, msg: SdProtect) -> Success:
if device.get_sd_salt_auth_key() is None:
raise wire.ProcessError("SD card protection not enabled")
# Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg)
# Get the current PIN and salt from the SD card.
pin, salt = await request_pin_and_sd_salt(ctx, "Enter PIN")
# Check PIN and remove salt.
if not config.change_pin(pin_to_int(pin), pin_to_int(pin), salt, None):
raise wire.PinInvalid("PIN invalid")
device.set_sd_salt_auth_key(None)
try:
# Clean up.
await remove_sd_salt(ctx)
except Exception:
# The cleanup is not necessary for the correct functioning of
# SD-protection. If it fails for any reason, we suppress the exception,
# because overall SD-protection was successfully disabled.
pass
return Success(message="SD card protection disabled")
async def sd_protect_refresh(ctx: wire.Context, msg: SdProtect) -> Success:
if device.get_sd_salt_auth_key() is None:
raise wire.ProcessError("SD card protection not enabled")
# Confirm that user wants to proceed with the operation.
await require_confirm_sd_protect(ctx, msg)
# Get the current PIN and salt from the SD card.
pin, old_salt = await request_pin_and_sd_salt(ctx, "Enter PIN")
# Check PIN and change salt.
new_salt = random.bytes(SD_SALT_LEN_BYTES)
new_salt_auth_key = random.bytes(SD_SALT_AUTH_KEY_LEN_BYTES)
new_salt_tag = hmac.new(new_salt_auth_key, new_salt, sha256).digest()[
:SD_SALT_AUTH_TAG_LEN_BYTES
]
try:
await stage_sd_salt(ctx, new_salt, new_salt_tag)
except Exception:
raise wire.ProcessError("Failed to write to SD card")
if not config.change_pin(pin_to_int(pin), pin_to_int(pin), old_salt, new_salt):
raise wire.PinInvalid("PIN invalid")
device.set_sd_salt_auth_key(new_salt_auth_key)
try:
# Clean up.
await commit_sd_salt(ctx)
except Exception:
# If the cleanup fails, then request_sd_salt() will bring the SD card
# into a consistent state. We suppress the exception, because overall
# SD-protection was successfully refreshed.
pass
return Success(message="SD card protection refreshed")
def require_confirm_sd_protect(ctx: wire.Context, msg: SdProtect) -> None:
if msg.operation == SdProtectOperationType.ENABLE:
text = Text("SD card protection", ui.ICON_CONFIG)
text.normal(
"Do you really want to", "secure your device with", "SD card protection?"
)
elif msg.operation == SdProtectOperationType.DISABLE:
text = Text("SD card protection", ui.ICON_CONFIG)
text.normal(
"Do you really want to", "remove SD card", "protection from your", "device?"
)
elif msg.operation == SdProtectOperationType.REFRESH:
text = Text("SD card protection", ui.ICON_CONFIG)
text.normal(
"Do you really want to",
"replace the current",
"SD card secret with a",
"newly generated one?",
)
else:
raise wire.ProcessError("Unknown operation")
return require_confirm(ctx, text)

View File

@ -3,21 +3,38 @@ from trezor.pin import pin_to_int, show_pin_timeout
from apps.common import storage from apps.common import storage
from apps.common.request_pin import request_pin from apps.common.request_pin import request_pin
from apps.common.sd_salt import request_sd_salt
from apps.common.storage import device
if False:
from typing import Optional
async def bootscreen() -> None: async def bootscreen() -> None:
ui.display.orientation(storage.device.get_rotation()) ui.display.orientation(storage.device.get_rotation())
salt_auth_key = device.get_sd_salt_auth_key()
while True: while True:
try: try:
if salt_auth_key is not None or config.has_pin():
await lockscreen()
if salt_auth_key is not None:
salt = await request_sd_salt(
None, salt_auth_key
) # type: Optional[bytearray]
else:
salt = None
if not config.has_pin(): if not config.has_pin():
config.unlock(pin_to_int("")) config.unlock(pin_to_int(""), salt)
storage.init_unlocked() storage.init_unlocked()
return return
await lockscreen()
label = "Enter your PIN" label = "Enter your PIN"
while True: while True:
pin = await request_pin(label, config.get_pin_rem()) pin = await request_pin(label, config.get_pin_rem())
if config.unlock(pin_to_int(pin)): if config.unlock(pin_to_int(pin), salt):
storage.init_unlocked() storage.init_unlocked()
return return
else: else:

View File

@ -27,7 +27,7 @@ class TestConfig(unittest.TestCase):
def test_wipe(self): def test_wipe(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int('')), True) self.assertEqual(config.unlock(pin_to_int(''), None), True)
config.set(1, 1, b'hello') config.set(1, 1, b'hello')
config.set(1, 2, b'world') config.set(1, 2, b'world')
v0 = config.get(1, 1) v0 = config.get(1, 1)
@ -44,7 +44,7 @@ class TestConfig(unittest.TestCase):
for _ in range(128): for _ in range(128):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int('')), True) self.assertEqual(config.unlock(pin_to_int(''), None), True)
appid, key = random_entry() appid, key = random_entry()
value = random.bytes(16) value = random.bytes(16)
config.set(appid, key, value) config.set(appid, key, value)
@ -58,7 +58,7 @@ class TestConfig(unittest.TestCase):
def test_public(self): def test_public(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int('')), True) self.assertEqual(config.unlock(pin_to_int(''), None), True)
appid, key = random_entry() appid, key = random_entry()
@ -84,25 +84,59 @@ class TestConfig(unittest.TestCase):
def test_change_pin(self): def test_change_pin(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int('')), True) self.assertEqual(config.unlock(pin_to_int(''), None), True)
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
config.set(PINAPP, PINKEY, b'value') config.set(PINAPP, PINKEY, b'value')
self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666')), False) self.assertEqual(config.change_pin(pin_to_int('000'), pin_to_int('666'), None, None), False)
self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000')), True) self.assertEqual(config.change_pin(pin_to_int(''), pin_to_int('000'), None, None), True)
self.assertEqual(config.get(PINAPP, PINKEY), None) self.assertEqual(config.get(PINAPP, PINKEY), None)
config.set(1, 1, b'value') config.set(1, 1, b'value')
config.init() config.init()
self.assertEqual(config.unlock(pin_to_int('000')), True) self.assertEqual(config.unlock(pin_to_int('000'), None), True)
config.change_pin(pin_to_int('000'), pin_to_int('')) config.change_pin(pin_to_int('000'), pin_to_int(''), None, None)
config.init() config.init()
self.assertEqual(config.unlock(pin_to_int('000')), False) self.assertEqual(config.unlock(pin_to_int('000'), None), False)
self.assertEqual(config.unlock(pin_to_int('')), True) self.assertEqual(config.unlock(pin_to_int(''), None), True)
self.assertEqual(config.get(1, 1), b'value')
def test_change_sd_salt(self):
salt1 = b"0123456789abcdef0123456789abcdef"
salt2 = b"0123456789ABCDEF0123456789ABCDEF"
# Enable PIN and SD salt.
config.init()
config.wipe()
self.assertTrue(config.unlock(pin_to_int(''), None))
config.set(1, 1, b'value')
self.assertFalse(config.change_pin(pin_to_int(''), pin_to_int(''), salt1, None))
self.assertTrue(config.change_pin(pin_to_int(''), pin_to_int('000'), None, salt1))
self.assertEqual(config.get(1, 1), b'value')
# Disable PIN and change SD salt.
config.init()
self.assertFalse(config.unlock(pin_to_int('000'), None))
self.assertIsNone(config.get(1, 1))
self.assertTrue(config.unlock(pin_to_int('000'), salt1))
self.assertTrue(config.change_pin(pin_to_int('000'), pin_to_int(''), salt1, salt2))
self.assertEqual(config.get(1, 1), b'value')
# Disable SD salt.
config.init()
self.assertFalse(config.unlock(pin_to_int('000'), salt2))
self.assertIsNone(config.get(1, 1))
self.assertTrue(config.unlock(pin_to_int(''), salt2))
self.assertTrue(config.change_pin(pin_to_int(''), pin_to_int(''), salt2, None))
self.assertEqual(config.get(1, 1), b'value')
# Check that PIN and SD salt are disabled.
config.init()
self.assertTrue(config.unlock(pin_to_int(''), None))
self.assertEqual(config.get(1, 1), b'value') self.assertEqual(config.get(1, 1), b'value')
def test_set_get(self): def test_set_get(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int('')), True) self.assertEqual(config.unlock(pin_to_int(''), None), True)
for _ in range(32): for _ in range(32):
appid, key = random_entry() appid, key = random_entry()
value = random.bytes(128) value = random.bytes(128)
@ -113,7 +147,7 @@ class TestConfig(unittest.TestCase):
def test_compact(self): def test_compact(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int('')), True) self.assertEqual(config.unlock(pin_to_int(''), None), True)
appid, key = 1, 1 appid, key = 1, 1
for _ in range(259): for _ in range(259):
value = random.bytes(259) value = random.bytes(259)
@ -124,7 +158,7 @@ class TestConfig(unittest.TestCase):
def test_get_default(self): def test_get_default(self):
config.init() config.init()
config.wipe() config.wipe()
self.assertEqual(config.unlock(pin_to_int('')), True) self.assertEqual(config.unlock(pin_to_int(''), None), True)
for _ in range(128): for _ in range(128):
appid, key = random_entry() appid, key = random_entry()
value = config.get(appid, key) value = config.get(appid, key)