diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index a0d96abff0..e445700989 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -1,4 +1,5 @@ import storage +from storage import cache from trezor import wire from trezor.crypto import bip32 @@ -6,11 +7,6 @@ from apps.cardano import CURVE, SEED_NAMESPACE from apps.common import mnemonic from apps.common.passphrase import get as get_passphrase -if False: - from typing import Optional - -_cached_root = None # type: Optional[bytes] - class Keychain: def __init__(self, path: list, root: bip32.HDNode): @@ -35,25 +31,30 @@ class Keychain: async def get_keychain(ctx: wire.Context) -> Keychain: - global _cached_root + root = cache.get(cache.APP_CARDANO_ROOT) if not storage.is_initialized(): raise wire.NotInitialized("Device is not initialized") - if _cached_root is None: + if root is None: passphrase = await get_passphrase(ctx) if mnemonic.is_bip39(): # derive the root node from mnemonic and passphrase - _cached_root = bip32.from_mnemonic_cardano( + root = bip32.from_mnemonic_cardano( mnemonic.get_secret().decode(), passphrase ) else: seed = mnemonic.get_seed(passphrase) - _cached_root = bip32.from_seed(seed, "ed25519 cardano seed") + root = bip32.from_seed(seed, "ed25519 cardano seed") + + storage.cache.set(cache.APP_CARDANO_ROOT, root) + + # let's not modify the one in the cache + root = root.clone() # derive the namespaced root node for i in SEED_NAMESPACE: - _cached_root.derive_cardano(i) + root.derive_cardano(i) - keychain = Keychain(SEED_NAMESPACE, _cached_root) + keychain = Keychain(SEED_NAMESPACE, root) return keychain diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index 375da41700..4a022b9ffc 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -1,5 +1,5 @@ import storage -import storage.cache +from storage import cache from trezor import wire from trezor.crypto import bip32, hashlib, hmac from trezor.crypto.curve import secp256k1 @@ -112,11 +112,11 @@ class Keychain: async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain: if not storage.is_initialized(): raise wire.NotInitialized("Device is not initialized") - seed = storage.cache.get_seed() + seed = cache.get(cache.APP_COMMON_SEED) if seed is None: passphrase = await get_passphrase(ctx) seed = mnemonic.get_seed(passphrase) - storage.cache.set_seed(seed) + cache.set(cache.APP_COMMON_SEED, seed) keychain = Keychain(seed, namespaces) return keychain @@ -126,10 +126,10 @@ def derive_node_without_passphrase( ) -> bip32.HDNode: if not storage.is_initialized(): raise Exception("Device is not initialized") - seed = storage.cache.get_seed_without_passphrase() + seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) if seed is None: seed = mnemonic.get_seed(progress_bar=False) - storage.cache.set_seed_without_passphrase(seed) + cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed) node = bip32.from_seed(seed, curve_name) node.derive_path(path) return node @@ -138,10 +138,10 @@ def derive_node_without_passphrase( def derive_slip21_node_without_passphrase(path: list) -> Slip21Node: if not storage.is_initialized(): raise Exception("Device is not initialized") - seed = storage.cache.get_seed_without_passphrase() + seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) if seed is None: seed = mnemonic.get_seed(progress_bar=False) - storage.cache.set_seed_without_passphrase(seed) + cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed) node = Slip21Node(seed) node.derive_path(path) return node diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 0fac8e5df5..2bd5d9bf53 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -3,39 +3,35 @@ from trezor.crypto import random if False: from typing import Optional -_cached_seed = None # type: Optional[bytes] -_cached_seed_without_passphrase = None # type: Optional[bytes] # Needed for SLIP-21 -_cached_session_id = None # type: Optional[bytes] +APP_COMMON_SEED = 0 +APP_COMMON_SEED_WITHOUT_PASSPHRASE = 1 +APP_CARDANO_ROOT = 2 + +_cache_session_id = None # type: Optional[bytes] +_cache = {} + +if False: + from typing import Any def get_session_id() -> bytes: - global _cached_session_id - if not _cached_session_id: - _cached_session_id = random.bytes(32) - return _cached_session_id + global _cache_session_id + if not _cache_session_id: + _cache_session_id = random.bytes(32) + return _cache_session_id -def set_seed(seed: Optional[bytes]) -> None: - global _cached_seed - _cached_seed = seed +def set(key: int, value: Any) -> None: + global _cache + _cache[key] = value -def get_seed() -> Optional[bytes]: - return _cached_seed - - -def set_seed_without_passphrase(seed: Optional[bytes]) -> None: - global _cached_seed_without_passphrase - _cached_seed_without_passphrase = seed - - -def get_seed_without_passphrase() -> Optional[bytes]: - return _cached_seed_without_passphrase +def get(key: int) -> Any: + return _cache.get(key) def clear() -> None: - global _cached_session_id - _cached_session_id = None - - set_seed(None) - set_seed_without_passphrase(None) + global _cache_session_id + global _cache + _cache_session_id = None + _cache.clear() diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 4ff7e24f3a..f52a9c0a6e 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -22,6 +22,7 @@ from trezorlib.tools import parse_path XPUB_PASSPHRASE_A = "xpub6CekxGcnqnJ6osfY4Rrq7W5ogFtR54KUvz4H16XzaQuukMFZCGebEpVznfq4yFcKEmYyShwj2UKjL7CazuNSuhdkofF4mHabHkLxCMVvsqG" XPUB_PASSPHRASE_NONE = "xpub6BiVtCpG9fQPxnPmHXG8PhtzQdWC2Su4qWu6XW9tpWFYhxydCLJGrWBJZ5H6qTAHdPQ7pQhtpjiYZVZARo14qHiay2fvrX996oEP42u8wZy" +XPUB_CARDANO_PASSPHRASE_B = "d80e770f6dfc3edb58eaab68aa091b2c27b08a47583471e93437ac5f8baa61880c7af4938a941c084c19731e6e57a5710e6ad1196263291aea297ce0eec0f177" def _get_xpub(client, passphrase): @@ -222,3 +223,52 @@ def test_passphrase_missing(client): response = client.call_raw(messages.PassphraseAck(passphrase=None, on_device=False)) assert isinstance(response, messages.Failure) assert response.code == FailureType.DataError + + +@pytest.mark.skip_ui +@pytest.mark.skip_t1 +@pytest.mark.altcoin +@pytest.mark.setup_client(passphrase=True) +def test_cardano_passphrase(client): + # Cardano uses a variation of BIP-39 so we need to ask for the passphrase again. + + response = client.call_raw(messages.Initialize()) + assert isinstance(response, messages.Features) + session_id = response.session_id + assert len(session_id) == 32 + + # GetPublicKey requires passphrase and since it is not cached, + # Trezor will prompt for it. + xpub = _get_xpub(client, passphrase="A") + assert xpub == XPUB_PASSPHRASE_A + + # The passphrase is now cached for non-Cardano coins. + xpub = _get_xpub(client, passphrase=None) + assert xpub == XPUB_PASSPHRASE_A + + # Cardano will prompt for it again. + response = client.call_raw( + messages.CardanoGetPublicKey(address_n=parse_path("44'/1815'/0'/0/0")) + ) + assert isinstance(response, messages.PassphraseRequest) + response = client.call_raw(messages.PassphraseAck(passphrase="B")) + assert response.xpub == XPUB_CARDANO_PASSPHRASE_B + + # But now also Cardano has it cached. + response = client.call_raw( + messages.CardanoGetPublicKey(address_n=parse_path("44'/1815'/0'/0/0")) + ) + assert response.xpub == XPUB_CARDANO_PASSPHRASE_B + + # And others behaviour did not change. + xpub = _get_xpub(client, passphrase=None) + assert xpub == XPUB_PASSPHRASE_A + + # Initialize with the session id does not destroy the state + client.call_raw(messages.Initialize(session_id=session_id)) + xpub = _get_xpub(client, passphrase=None) + assert xpub == XPUB_PASSPHRASE_A + response = client.call_raw( + messages.CardanoGetPublicKey(address_n=parse_path("44'/1815'/0'/0/0")) + ) + assert response.xpub == XPUB_CARDANO_PASSPHRASE_B