From ff4ec2185eec7c94bbcd47fa432f0c3b5926679c Mon Sep 17 00:00:00 2001 From: matejcik Date: Tue, 14 Jul 2020 12:15:05 +0200 Subject: [PATCH] core: refactor keychain to only support one curve at a time also make a cleaner distinction between keychain, seed, path This enables using `unsafe_prompts`, because with the original code, if there was no namespace match, we wouldn't know which curve to use. For ease of implementation, we use a LRU cache for derived keys, instead of the original design "one cache entry per namespace". SLIP21 is now treated completely separately, via `slip21_namespaces` and `derive_slip21` method. If more slip21-like things come in the future, we can instead hang them on the keychain: put a per-curve Keychain object accessible by `keychain[curve_name].derive()`, and the majority usecase will just pass around `keychain[curve_name]` instead of having to specify the curve in every `derive()` call. Or alternately we'll just specify the curve in every `derive()` call, whichever seems more appropriate. --- core/src/apps/cardano/seed.py | 7 +- core/src/apps/common/keychain.py | 181 +++++++++++++++++++++++++++++++ core/src/apps/common/paths.py | 28 +++-- core/src/apps/common/seed.py | 121 +-------------------- 4 files changed, 206 insertions(+), 131 deletions(-) create mode 100644 core/src/apps/common/keychain.py diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index 9c87b4599..2a4512abc 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -7,8 +7,6 @@ from apps.common import mnemonic from apps.common.passphrase import get as get_passphrase if False: - from typing import Tuple - from apps.common.seed import Bip32Path, MsgIn, MsgOut, Handler, HandlerWithKeychain @@ -18,15 +16,14 @@ class Keychain: def __init__(self, root: bip32.HDNode) -> None: self.root = root - def match_path(self, path: Bip32Path) -> Tuple[int, Bip32Path]: + def verify_path(self, path: Bip32Path) -> None: if path[: len(SEED_NAMESPACE)] != SEED_NAMESPACE: raise wire.DataError("Forbidden key path") - return 0, path[len(SEED_NAMESPACE) :] def derive(self, node_path: Bip32Path) -> bip32.HDNode: - _, suffix = self.match_path(node_path) # derive child node from the root node = self.root.clone() + suffix = node_path[len(SEED_NAMESPACE) :] for i in suffix: node.derive_cardano(i) return node diff --git a/core/src/apps/common/keychain.py b/core/src/apps/common/keychain.py new file mode 100644 index 000000000..b6f9ede3f --- /dev/null +++ b/core/src/apps/common/keychain.py @@ -0,0 +1,181 @@ +from storage import device +from trezor import wire +from trezor.crypto import bip32 + +from . import HARDENED, paths +from .seed import Slip21Node, get_seed + +if False: + from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Sequence, + TypeVar, + ) + from typing_extensions import Protocol + + from protobuf import MessageType + + T = TypeVar("T") + + class NodeProtocol(Protocol[paths.PathType]): + def derive_path(self, path: paths.PathType) -> None: + ... + + def clone(self: T) -> T: + ... + + def __del__(self) -> None: + ... + + NodeType = TypeVar("NodeType", bound=NodeProtocol) + + MsgIn = TypeVar("MsgIn", bound=MessageType) + MsgOut = TypeVar("MsgOut", bound=MessageType) + + Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]] + HandlerWithKeychain = Callable[[wire.Context, MsgIn, "Keychain"], Awaitable[MsgOut]] + + class Deletable(Protocol): + def __del__(self) -> None: + ... + + +FORBIDDEN_KEY_PATH = wire.DataError("Forbidden key path") + + +class LRUCache: + def __init__(self, size: int) -> None: + self.size = size + self.cache_keys = [] # type: List[Any] + self.cache = {} # type: Dict[Any, Deletable] + + def insert(self, key: Any, value: Deletable) -> None: + if key in self.cache_keys: + self.cache_keys.remove(key) + self.cache_keys.insert(0, key) + self.cache[key] = value + + if len(self.cache_keys) > self.size: + dropped_key = self.cache_keys.pop() + self.cache[dropped_key].__del__() + del self.cache[dropped_key] + + def get(self, key: Any) -> Any: + if key not in self.cache: + return None + + self.cache_keys.remove(key) + self.cache_keys.insert(0, key) + return self.cache[key] + + def __del__(self) -> None: + for value in self.cache.values(): + value.__del__() + self.cache.clear() + self.cache_keys.clear() + del self.cache + + +class Keychain: + def __init__( + self, + seed: bytes, + curve: str, + namespaces: Sequence[paths.Bip32Path], + slip21_namespaces: Sequence[paths.Slip21Path] = (), + ) -> None: + self.seed = seed + self.curve = curve + self.namespaces = namespaces + self.slip21_namespaces = slip21_namespaces + + self._cache = LRUCache(10) + + def __del__(self) -> None: + self._cache.__del__() + del self._cache + del self.seed + + def verify_path(self, path: paths.Bip32Path) -> None: + if "ed25519" in self.curve and not paths.path_is_hardened(path): + raise FORBIDDEN_KEY_PATH + + if device.unsafe_prompts_allowed(): + return + + if any(ns == path[: len(ns)] for ns in self.namespaces): + return + + raise FORBIDDEN_KEY_PATH + + def _derive_with_cache( + self, prefix_len: int, path: paths.PathType, new_root: Callable[[], NodeType], + ) -> NodeType: + cached_prefix = tuple(path[:prefix_len]) + cached_root = self._cache.get(cached_prefix) # type: Optional[NodeType] + if cached_root is None: + cached_root = new_root() + cached_root.derive_path(cached_prefix) + self._cache.insert(cached_prefix, cached_root) + + node = cached_root.clone() + node.derive_path(path[prefix_len:]) + return node + + def derive(self, path: paths.Bip32Path) -> bip32.HDNode: + self.verify_path(path) + return self._derive_with_cache( + prefix_len=3, + path=path, + new_root=lambda: bip32.from_seed(self.seed, self.curve), + ) + + def derive_slip21(self, path: paths.Slip21Path) -> Slip21Node: + if not device.unsafe_prompts_allowed() and not any( + ns == path[: len(ns)] for ns in self.slip21_namespaces + ): + raise FORBIDDEN_KEY_PATH + + return self._derive_with_cache( + prefix_len=1, path=path, new_root=lambda: Slip21Node(seed=self.seed), + ) + + def __enter__(self) -> "Keychain": + return self + + def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None: + self.__del__() + + +async def get_keychain( + ctx: wire.Context, + curve: str, + namespaces: Sequence[paths.Bip32Path], + slip21_namespaces: Sequence[paths.Slip21Path] = (), +) -> Keychain: + seed = await get_seed(ctx) + keychain = Keychain(seed, curve, namespaces, slip21_namespaces) + return keychain + + +def with_slip44_keychain( + slip44: int, curve: str = "secp256k1", allow_testnet: bool = False +) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]: + namespaces = [[44 | HARDENED, slip44 | HARDENED]] + if allow_testnet: + namespaces.append([44 | HARDENED, 1 | HARDENED]) + + def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: + async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: + keychain = await get_keychain(ctx, curve, namespaces) + with keychain: + return await func(ctx, msg, keychain) + + return wrapper + + return decorator diff --git a/core/src/apps/common/paths.py b/core/src/apps/common/paths.py index 603037788..506a1df16 100644 --- a/core/src/apps/common/paths.py +++ b/core/src/apps/common/paths.py @@ -4,29 +4,35 @@ from trezor import ui from trezor.messages import ButtonRequestType from trezor.ui.text import Text -from apps.common import HARDENED -from apps.common.confirm import require_confirm +from . import HARDENED +from .confirm import require_confirm if False: - from typing import Any, Callable, List + from typing import Any, Callable, List, Sequence, TypeVar from trezor import wire - from apps.common import seed + + # XXX this is a circular import, but it's only for typing + from .keychain import Keychain + + Bip32Path = Sequence[int] + Slip21Path = Sequence[bytes] + PathType = TypeVar("PathType", Bip32Path, Slip21Path) async def validate_path( ctx: wire.Context, validate_func: Callable[..., bool], - keychain: seed.Keychain, + keychain: Keychain, path: List[int], curve: str, **kwargs: Any, ) -> None: - keychain.match_path(path) + keychain.verify_path(path) if not validate_func(path, **kwargs): await show_path_warning(ctx, path) -async def show_path_warning(ctx: wire.Context, path: List[int]) -> None: +async def show_path_warning(ctx: wire.Context, path: Bip32Path) -> None: text = Text("Confirm path", ui.ICON_WRONG, ui.RED) text.normal("Path") text.mono(*break_address_n_to_lines(path)) @@ -35,7 +41,7 @@ async def show_path_warning(ctx: wire.Context, path: List[int]) -> None: await require_confirm(ctx, text, ButtonRequestType.UnknownDerivationPath) -def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool: +def validate_path_for_get_public_key(path: Bip32Path, slip44_id: int) -> bool: """ Checks if path has at least three hardened items and slip44 id matches. The path is allowed to have more than three items, but all the following @@ -61,7 +67,11 @@ def is_hardened(i: int) -> bool: return bool(i & HARDENED) -def break_address_n_to_lines(address_n: list) -> list: +def path_is_hardened(address_n: Bip32Path) -> bool: + return all(is_hardened(n) for n in address_n) + + +def break_address_n_to_lines(address_n: Bip32Path) -> List[str]: def path_item(i: int) -> str: if i & HARDENED: return str(i ^ HARDENED) + "'" diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index 988899dbf..debcb02f8 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -2,39 +2,11 @@ from storage import cache, device from trezor import wire from trezor.crypto import bip32, hashlib, hmac -from apps.common import HARDENED, mnemonic -from apps.common.passphrase import get as get_passphrase +from . import mnemonic +from .passphrase import get as get_passphrase if False: - from typing import ( - Any, - Awaitable, - Callable, - Dict, - List, - Sequence, - Tuple, - TypeVar, - ) - from typing_extensions import Protocol - - Bip32Path = List[int] - Slip21Path = List[bytes] - PathType = TypeVar("PathType", Bip32Path, Slip21Path) - - Namespace = Tuple[str, PathType] - - T = TypeVar("T") - - class NodeType(Protocol[PathType]): - def __del__(self) -> None: - ... - - def derive_path(self, path: PathType) -> None: - ... - - def clone(self: T) -> T: - ... + from .paths import Bip32Path, Slip21Path class Slip21Node: @@ -63,55 +35,8 @@ class Slip21Node: return Slip21Node(data=self.data) -class Keychain: - def __init__(self, seed: bytes, namespaces: Sequence[Namespace]) -> None: - self.seed = seed - self.namespaces = namespaces # type: Sequence[Namespace] - self.roots = {} # type: Dict[int, NodeType] - - def __del__(self) -> None: - for root in self.roots.values(): - root.__del__() - del self.roots - del self.seed - - def match_path(self, path: PathType) -> Tuple[int, PathType]: - for i, (curve, ns) in enumerate(self.namespaces): - if path[: len(ns)] == ns: - if "ed25519" in curve and not _path_hardened(path): - raise wire.DataError("Forbidden key path") - return i, path[len(ns) :] - - raise wire.DataError("Forbidden key path") - - def _new_root(self, curve: str) -> NodeType: - if curve == "slip21": - return Slip21Node(self.seed) - else: - return bip32.from_seed(self.seed, curve) - - def derive(self, path: PathType) -> NodeType: - root_index, suffix = self.match_path(path) - - if root_index not in self.roots: - curve, prefix = self.namespaces[root_index] - root = self._new_root(curve) - root.derive_path(prefix) - self.roots[root_index] = root - - node = self.roots[root_index].clone() - node.derive_path(suffix) - return node - - def __enter__(self) -> "Keychain": - return self - - def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None: - self.__del__() - - @cache.stored_async(cache.APP_COMMON_SEED) -async def _get_seed(ctx: wire.Context) -> bytes: +async def get_seed(ctx: wire.Context) -> bytes: if not device.is_initialized(): raise wire.NotInitialized("Device is not initialized") passphrase = await get_passphrase(ctx) @@ -125,12 +50,6 @@ def _get_seed_without_passphrase() -> bytes: return mnemonic.get_seed(progress_bar=False) -async def get_keychain(ctx: wire.Context, namespaces: Sequence[Namespace]) -> Keychain: - seed = await _get_seed(ctx) - keychain = Keychain(seed, namespaces) - return keychain - - def derive_node_without_passphrase( path: Bip32Path, curve_name: str = "secp256k1" ) -> bip32.HDNode: @@ -150,35 +69,3 @@ def derive_slip21_node_without_passphrase(path: Slip21Path) -> Slip21Node: def remove_ed25519_prefix(pubkey: bytes) -> bytes: # 0x01 prefix is not part of the actual public key, hence removed return pubkey[1:] - - -def _path_hardened(path: list) -> bool: - return all(i & HARDENED for i in path) - - -if False: - from protobuf import MessageType - - MsgIn = TypeVar("MsgIn", bound=MessageType) - MsgOut = TypeVar("MsgOut", bound=MessageType) - - Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]] - HandlerWithKeychain = Callable[[wire.Context, MsgIn, Keychain], Awaitable[MsgOut]] - - -def with_slip44_keychain( - slip44: int, curve: str = "secp256k1", allow_testnet: bool = False -) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]: - namespaces = [(curve, [44 | HARDENED, slip44 | HARDENED])] - if allow_testnet: - namespaces.append((curve, [44 | HARDENED, 1 | HARDENED])) - - def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: - async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: - keychain = await get_keychain(ctx, namespaces) - with keychain: - return await func(ctx, msg, keychain) - - return wrapper - - return decorator