diff --git a/common/protob/messages-management.proto b/common/protob/messages-management.proto index 7fc64eebd0..e26d84ea38 100644 --- a/common/protob/messages-management.proto +++ b/common/protob/messages-management.proto @@ -87,6 +87,8 @@ message Features { Capability_ShamirGroups = 16; } optional BackupType backup_type = 31; // type of device backup (BIP-39 / SLIP-39 basic / SLIP-39 advanced) + optional bool sd_card_present = 32; // is SD card present + optional bool sd_protection = 33; // is SD Protect enabled } /** diff --git a/core/src/apps/homescreen/__init__.py b/core/src/apps/homescreen/__init__.py index 339d7461bc..4d8e78598c 100644 --- a/core/src/apps/homescreen/__init__.py +++ b/core/src/apps/homescreen/__init__.py @@ -1,4 +1,4 @@ -from trezor import config, utils, wire +from trezor import config, io, utils, wire from trezor.messages import Capability, MessageType from trezor.messages.Features import Features from trezor.messages.Success import Success @@ -63,6 +63,8 @@ def get_features() -> Features: Capability.Shamir, Capability.ShamirGroups, ] + f.sd_card_present = io.SDCard().present() + f.sd_protection = storage.device.get_sd_salt_auth_key() is not None return f diff --git a/core/src/trezor/messages/Features.py b/core/src/trezor/messages/Features.py index 7f890f8cb5..191b5c69c5 100644 --- a/core/src/trezor/messages/Features.py +++ b/core/src/trezor/messages/Features.py @@ -49,6 +49,8 @@ class Features(p.MessageType): recovery_mode: bool = None, capabilities: List[EnumTypeCapability] = None, backup_type: EnumTypeBackupType = None, + sd_card_present: bool = None, + sd_protection: bool = None, ) -> None: self.vendor = vendor self.major_version = major_version @@ -80,6 +82,8 @@ class Features(p.MessageType): self.recovery_mode = recovery_mode self.capabilities = capabilities if capabilities is not None else [] self.backup_type = backup_type + self.sd_card_present = sd_card_present + self.sd_protection = sd_protection @classmethod def get_fields(cls) -> Dict: @@ -114,4 +118,6 @@ class Features(p.MessageType): 29: ('recovery_mode', p.BoolType, 0), 30: ('capabilities', p.EnumType("Capability", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), p.FLAG_REPEATED), 31: ('backup_type', p.EnumType("BackupType", (0, 1, 2)), 0), + 32: ('sd_card_present', p.BoolType, 0), + 33: ('sd_protection', p.BoolType, 0), } diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index 357dfdfd24..ac7d650d7c 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -93,6 +93,7 @@ def change_pin(client, remove=False): @expect(proto.Success, field="message") def sd_protect(client, operation): ret = client.call(proto.SdProtect(operation=operation)) + client.init_device() return ret diff --git a/python/src/trezorlib/messages/Features.py b/python/src/trezorlib/messages/Features.py index 3cf6649f7c..b21c39dfc6 100644 --- a/python/src/trezorlib/messages/Features.py +++ b/python/src/trezorlib/messages/Features.py @@ -49,6 +49,8 @@ class Features(p.MessageType): recovery_mode: bool = None, capabilities: List[EnumTypeCapability] = None, backup_type: EnumTypeBackupType = None, + sd_card_present: bool = None, + sd_protection: bool = None, ) -> None: self.vendor = vendor self.major_version = major_version @@ -80,6 +82,8 @@ class Features(p.MessageType): self.recovery_mode = recovery_mode self.capabilities = capabilities if capabilities is not None else [] self.backup_type = backup_type + self.sd_card_present = sd_card_present + self.sd_protection = sd_protection @classmethod def get_fields(cls) -> Dict: @@ -114,4 +118,6 @@ class Features(p.MessageType): 29: ('recovery_mode', p.BoolType, 0), 30: ('capabilities', p.EnumType("Capability", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), p.FLAG_REPEATED), 31: ('backup_type', p.EnumType("BackupType", (0, 1, 2)), 0), + 32: ('sd_card_present', p.BoolType, 0), + 33: ('sd_protection', p.BoolType, 0), } diff --git a/tests/REGISTERED_MARKERS b/tests/REGISTERED_MARKERS index 4cff69b25e..e73668e237 100644 --- a/tests/REGISTERED_MARKERS +++ b/tests/REGISTERED_MARKERS @@ -11,6 +11,7 @@ monero nem ontology ripple +sd_card stellar tezos zcash diff --git a/tests/conftest.py b/tests/conftest.py index 25493f4459..faa1ad2508 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,6 +77,16 @@ def client(request): if request.node.get_closest_marker("skip_t1") and client.features.model == "1": pytest.skip("Test excluded on Trezor 1") + if ( + request.node.get_closest_marker("sd_card") + and not client.features.sd_card_present + ): + raise RuntimeError( + "This test requires SD card.\n" + "To skip all such tests, run:\n" + " pytest -m 'not sd_card' " + ) + wipe_device(client) # fmt: off @@ -141,7 +151,7 @@ def pytest_runtest_setup(item): both T1 and TT. """ if item.get_closest_marker("skip_t1") and item.get_closest_marker("skip_t2"): - pytest.fail("Don't skip tests for both trezors!") + raise RuntimeError("Don't skip tests for both trezors!") skip_altcoins = int(os.environ.get("TREZOR_PYTEST_SKIP_ALTCOINS", 0)) if item.get_closest_marker("altcoin") and skip_altcoins: diff --git a/tests/device_tests/test_msg_sd_protect.py b/tests/device_tests/test_msg_sd_protect.py index 943d968fcf..324795b083 100644 --- a/tests/device_tests/test_msg_sd_protect.py +++ b/tests/device_tests/test_msg_sd_protect.py @@ -16,47 +16,73 @@ import pytest -from trezorlib import debuglink, device, messages as proto +from trezorlib import debuglink, device from trezorlib.exceptions import TrezorFailure +from trezorlib.messages import SdProtectOperationType as Op from ..common import MNEMONIC12 +pytestmark = [pytest.mark.skip_t1, pytest.mark.sd_card] -@pytest.mark.skip_t1 -class TestMsgSdProtect: - @pytest.mark.setup_client(mnemonic=MNEMONIC12) - def test_sd_protect(self, client): - # Disabling SD protection should fail - with pytest.raises(TrezorFailure): - device.sd_protect(client, proto.SdProtectOperationType.DISABLE) +def test_enable_disable(client): + assert client.features.sd_protection is False + # Disabling SD protection should fail + with pytest.raises(TrezorFailure): + device.sd_protect(client, Op.DISABLE) - # Enable SD protection - device.sd_protect(client, proto.SdProtectOperationType.ENABLE) + # Enable SD protection + device.sd_protect(client, Op.ENABLE) + assert client.features.sd_protection is True - # Enabling SD protection should fail - with pytest.raises(TrezorFailure): - device.sd_protect(client, proto.SdProtectOperationType.ENABLE) + # Enabling SD protection should fail + with pytest.raises(TrezorFailure): + device.sd_protect(client, Op.ENABLE) + assert client.features.sd_protection is True - # Wipe - device.wipe(client) - debuglink.load_device_by_mnemonic( - client, - mnemonic=MNEMONIC12, - pin="", - passphrase_protection=False, - label="test", - ) + # Disable SD protection + device.sd_protect(client, Op.DISABLE) + assert client.features.sd_protection is False - # Enable SD protection - device.sd_protect(client, proto.SdProtectOperationType.ENABLE) - # Refresh SD protection - device.sd_protect(client, proto.SdProtectOperationType.REFRESH) +def test_refresh(client): + assert client.features.sd_protection is False + # Enable SD protection + device.sd_protect(client, Op.ENABLE) + assert client.features.sd_protection is True - # Disable SD protection - device.sd_protect(client, proto.SdProtectOperationType.DISABLE) + # Refresh SD protection + device.sd_protect(client, Op.REFRESH) + assert client.features.sd_protection is True - # Refreshing SD protection should fail - with pytest.raises(TrezorFailure): - device.sd_protect(client, proto.SdProtectOperationType.REFRESH) + # Disable SD protection + device.sd_protect(client, Op.DISABLE) + assert client.features.sd_protection is False + + # Refreshing SD protection should fail + with pytest.raises(TrezorFailure): + device.sd_protect(client, Op.REFRESH) + assert client.features.sd_protection is False + + +def test_wipe(client): + # Enable SD protection + device.sd_protect(client, Op.ENABLE) + assert client.features.sd_protection is True + + # Wipe device (this wipes internal storage) + device.wipe(client) + assert client.features.sd_protection is False + + # Restore device to working status + debuglink.load_device_by_mnemonic( + client, mnemonic=MNEMONIC12, pin=None, passphrase_protection=False, label="test" + ) + assert client.features.sd_protection is False + + # Enable SD protection + device.sd_protect(client, Op.ENABLE) + assert client.features.sd_protection is True + + # Refresh SD protection + device.sd_protect(client, Op.REFRESH)