diff --git a/core/src/apps/common/paths.py b/core/src/apps/common/paths.py index 2cbb8017c5..603037788f 100644 --- a/core/src/apps/common/paths.py +++ b/core/src/apps/common/paths.py @@ -21,7 +21,7 @@ async def validate_path( curve: str, **kwargs: Any, ) -> None: - keychain.validate_path(path, curve) + keychain.match_path(path) if not validate_func(path, **kwargs): await show_path_warning(ctx, path) @@ -58,10 +58,7 @@ def validate_path_for_get_public_key(path: list, slip44_id: int) -> bool: def is_hardened(i: int) -> bool: - if i & HARDENED: - return True - else: - return False + return bool(i & HARDENED) def break_address_n_to_lines(address_n: list) -> list: diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index 4a022b9ffc..bcf4d2417d 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -2,26 +2,56 @@ import storage from storage import cache from trezor import wire from trezor.crypto import bip32, hashlib, hmac -from trezor.crypto.curve import secp256k1 from apps.common import HARDENED, mnemonic from apps.common.passphrase import get as get_passphrase if False: - from typing import List, Union + from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Sequence, + Tuple, + TypeVar, + ) + from typing_extensions import Protocol + + Bip32Path = List[int] + Slip21Path = List[bytes] + PathType = TypeVar("PathType", Bip32Path, Slip21Path) + + Namespace = Tuple[str, PathType] + + T = TypeVar("T") + + class NodeType(Protocol[PathType]): + def __del__(self) -> None: + ... + + def derive_path(self, path: PathType) -> None: + ... + + def clone(self: T) -> T: + ... class Slip21Node: - def __init__(self, seed: bytes = None) -> None: - if seed is not None: + def __init__(self, seed: bytes = None, data: bytes = None) -> None: + assert seed is None or data is None, "Specify exactly one of: seed, data" + if data is not None: + self.data = data + elif seed is not None: self.data = hmac.new(b"Symmetric key seed", seed, hashlib.sha512).digest() else: - self.data = b"" + raise ValueError # neither seed nor data specified def __del__(self) -> None: del self.data - def derive_path(self, path: list) -> None: + def derive_path(self, path: Slip21Path) -> None: for label in path: h = hmac.new(self.data[0:32], b"\x00", hashlib.sha512) h.update(label) @@ -31,117 +61,88 @@ class Slip21Node: return self.data[32:64] def clone(self) -> "Slip21Node": - node = Slip21Node() - node.data = self.data - return node + return Slip21Node(data=self.data) class Keychain: - """ - Keychain provides an API for deriving HD keys from previously allowed - key-spaces. - """ - - def __init__(self, seed: bytes, namespaces: list): + def __init__(self, seed: bytes, namespaces: Sequence[Namespace]) -> None: self.seed = seed - self.namespaces = namespaces - self.roots = [None] * len( - namespaces - ) # type: List[Union[bip32.HDNode, Slip21Node, None]] + self.namespaces = namespaces # type: Sequence[Namespace] + self.roots = {} # type: Dict[int, NodeType] def __del__(self) -> None: - for root in self.roots: - if root is not None and hasattr(root, "__del__"): - root.__del__() + for root in self.roots.values(): + root.__del__() del self.roots del self.seed - def validate_path(self, checked_path: list, checked_curve: str) -> None: - for curve, *path in self.namespaces: - if path == checked_path[: len(path)] and curve == checked_curve: - if "ed25519" in curve and not _path_hardened(checked_path): - break - return + def match_path(self, path: PathType) -> Tuple[int, PathType]: + for i, (curve, ns) in enumerate(self.namespaces): + if path[: len(ns)] == ns: + if "ed25519" in curve and not _path_hardened(path): + raise wire.DataError("Forbidden key path") + return i, path[len(ns) :] + raise wire.DataError("Forbidden key path") - def derive( - self, node_path: list, curve_name: str = "secp256k1" - ) -> Union[bip32.HDNode, Slip21Node]: - if "ed25519" in curve_name and not _path_hardened(node_path): - raise wire.DataError("Forbidden key path") - - # find the root node index - root_index = 0 - for curve, *path in self.namespaces: - prefix = node_path[: len(path)] - suffix = node_path[len(path) :] - if curve == curve_name and path == prefix: - break - root_index += 1 + def _new_root(self, curve: str) -> NodeType: + if curve == "slip21": + return Slip21Node(self.seed) else: - raise wire.DataError("Forbidden key path") + return bip32.from_seed(self.seed, curve) - # create the root node if not cached - root = self.roots[root_index] - if root is None: - if curve_name != "slip21": - root = bip32.from_seed(self.seed, curve_name) - else: - root = Slip21Node(self.seed) - root.derive_path(path) + def derive(self, path: PathType) -> NodeType: + root_index, suffix = self.match_path(path) + + if root_index not in self.roots: + curve, prefix = self.namespaces[root_index] + root = self._new_root(curve) + root.derive_path(prefix) self.roots[root_index] = root - # derive child node from the root - node = root.clone() + node = self.roots[root_index].clone() node.derive_path(suffix) return node - def derive_slip77_blinding_private_key(self, script: bytes) -> bytes: - """Following the derivation by Elements/Liquid.""" - master_node = self.derive(node_path=[b"SLIP-0077"], curve_name="slip21") - assert isinstance(master_node, Slip21Node) - return hmac.new( - key=master_node.key(), msg=script, digestmod=hashlib.sha256 - ).digest() + def __enter__(self) -> "Keychain": + return self - def derive_slip77_blinding_public_key(self, script: bytes) -> bytes: - private_key = self.derive_slip77_blinding_private_key(script) - return secp256k1.publickey(private_key) + def __exit__(self, exc_type: Any, exc_val: Any, tb: Any) -> None: + self.__del__() -async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain: +@cache.stored_async(cache.APP_COMMON_SEED) +async def _get_seed(ctx: wire.Context) -> bytes: if not storage.is_initialized(): raise wire.NotInitialized("Device is not initialized") - seed = cache.get(cache.APP_COMMON_SEED) - if seed is None: - passphrase = await get_passphrase(ctx) - seed = mnemonic.get_seed(passphrase) - cache.set(cache.APP_COMMON_SEED, seed) + passphrase = await get_passphrase(ctx) + return mnemonic.get_seed(passphrase) + + +@cache.stored(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) +def _get_seed_without_passphrase() -> bytes: + if not storage.is_initialized(): + raise Exception("Device is not initialized") + return mnemonic.get_seed(progress_bar=False) + + +async def get_keychain(ctx: wire.Context, namespaces: Sequence[Namespace]) -> Keychain: + seed = await _get_seed(ctx) keychain = Keychain(seed, namespaces) return keychain def derive_node_without_passphrase( - path: list, curve_name: str = "secp256k1" + path: Bip32Path, curve_name: str = "secp256k1" ) -> bip32.HDNode: - if not storage.is_initialized(): - raise Exception("Device is not initialized") - seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) - if seed is None: - seed = mnemonic.get_seed(progress_bar=False) - cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed) + seed = _get_seed_without_passphrase() node = bip32.from_seed(seed, curve_name) node.derive_path(path) return node -def derive_slip21_node_without_passphrase(path: list) -> Slip21Node: - if not storage.is_initialized(): - raise Exception("Device is not initialized") - seed = cache.get(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) - if seed is None: - seed = mnemonic.get_seed(progress_bar=False) - cache.set(cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE, seed) +def derive_slip21_node_without_passphrase(path: Slip21Path) -> Slip21Node: + seed = _get_seed_without_passphrase() node = Slip21Node(seed) node.derive_path(path) return node @@ -154,3 +155,31 @@ def remove_ed25519_prefix(pubkey: bytes) -> bytes: def _path_hardened(path: list) -> bool: return all(i & HARDENED for i in path) + + +if False: + from protobuf import MessageType + + MsgIn = TypeVar("MsgIn", bound=MessageType) + MsgOut = TypeVar("MsgOut", bound=MessageType) + + Handler = Callable[[wire.Context, MsgIn], Awaitable[MsgOut]] + HandlerWithKeychain = Callable[[wire.Context, MsgIn, Keychain], Awaitable[MsgOut]] + + +def with_slip44_keychain( + slip44: int, curve: str = "secp256k1", allow_testnet: bool = False +) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]: + namespaces = [(curve, [44 | HARDENED, slip44 | HARDENED])] + if allow_testnet: + namespaces.append((curve, [44 | HARDENED, 1 | HARDENED])) + + def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: + async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: + keychain = await get_keychain(ctx, namespaces) + with keychain: + return await func(ctx, msg, keychain) + + return wrapper + + return decorator